This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new f086c2327c36 [SPARK-47174][CONNECT][SS][1/2] Server side SparkConnectListenerBusListener for Client side streaming query listener f086c2327c36 is described below commit f086c2327c36c396ae5d886afd3ef613650c6b0d Author: Wei Liu <wei....@databricks.com> AuthorDate: Fri Apr 12 10:08:45 2024 +0900 [SPARK-47174][CONNECT][SS][1/2] Server side SparkConnectListenerBusListener for Client side streaming query listener ### What changes were proposed in this pull request? Server side `SparkConnectListenerBusListener` implementation for the client side listener. There would only be one such listener for each `SessionHolder`. ### Why are the changes needed? Move streaming query listener to client side ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added unit test ### Was this patch authored or co-authored using generative AI tooling? No Closes #45988 from WweiL/SPARK-47174-client-side-listener-1. 
Authored-by: Wei Liu <wei....@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../sql/connect/planner/SparkConnectPlanner.scala | 34 ++- ...SparkConnectStreamingQueryListenerHandler.scala | 121 +++++++++++ .../spark/sql/connect/service/SessionHolder.scala | 7 + .../service/SparkConnectListenerBusListener.scala | 156 ++++++++++++++ .../SparkConnectListenerBusListenerSuite.scala | 240 +++++++++++++++++++++ 5 files changed, 555 insertions(+), 3 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 96db45c5c63e..5e7f3b74c299 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -2551,6 +2551,11 @@ class SparkConnectPlanner( handleStreamingQueryManagerCommand( command.getStreamingQueryManagerCommand, responseObserver) + case proto.Command.CommandTypeCase.STREAMING_QUERY_LISTENER_BUS_COMMAND => + val handler = new SparkConnectStreamingQueryListenerHandler(executeHolder) + handler.handleListenerCommand( + command.getStreamingQueryListenerBusCommand, + responseObserver) case proto.Command.CommandTypeCase.GET_RESOURCES_COMMAND => handleGetResourcesCommand(responseObserver) case proto.Command.CommandTypeCase.CREATE_RESOURCE_PROFILE_COMMAND => @@ -3118,7 +3123,7 @@ class SparkConnectPlanner( } executeHolder.eventsManager.postFinished() - val result = WriteStreamOperationStartResult + val resultBuilder = WriteStreamOperationStartResult .newBuilder() .setQueryId( StreamingQueryInstanceId @@ -3127,14 +3132,37 @@ class SparkConnectPlanner( .setRunId(query.runId.toString) .build()) .setName(Option(query.name).getOrElse("")) - .build() + + // The query started event for this query is sent to the client, 
and is handled by + // the client side listeners before client's DataStreamWriter.start() returns. + // This is to ensure that the onQueryStarted callback is called before the start() call, which + // is defined in the onQueryStarted API. + // So the flow is: + // 1. On the server side, the query is started above. + // 2. Per the contract of the onQueryStarted API, the queryStartedEvent is added to the + // streamingServersideListenerHolder.streamingQueryStartedEventCache, by the onQueryStarted + // callback of streamingServersideListenerHolder.streamingQueryServerSideListener. + // 3. The queryStartedEvent is sent to the client. + // 4. The client side listener handles the queryStartedEvent and calls the onQueryStarted API, + // before the client side DataStreamWriter.start() returns. + // This way we ensure that the onQueryStarted API is called before the start() call in Connect. + val queryStartedEvent = Option( + sessionHolder.streamingServersideListenerHolder.streamingQueryStartedEventCache.remove( + query.runId.toString)) + queryStartedEvent.foreach { e => + logDebug( + s"[SessionId: $sessionId][UserId: $userId][operationId: " + + s"${executeHolder.operationId}][query id: ${query.id}][query runId: ${query.runId}] " + + s"Adding QueryStartedEvent to response") + resultBuilder.setQueryStartedEventJson(e.json) + } responseObserver.onNext( ExecutePlanResponse .newBuilder() .setSessionId(sessionId) .setServerSideSessionId(sessionHolder.serverSessionId) - .setWriteStreamOperationStartResult(result) + .setWriteStreamOperationStartResult(resultBuilder.build()) .build()) } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectStreamingQueryListenerHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectStreamingQueryListenerHandler.scala new file mode 100644 index 000000000000..94f01026b7a5 --- /dev/null +++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectStreamingQueryListenerHandler.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.planner + +import scala.util.control.NonFatal + +import io.grpc.stub.StreamObserver + +import org.apache.spark.connect.proto.ExecutePlanResponse +import org.apache.spark.connect.proto.StreamingQueryListenerBusCommand +import org.apache.spark.connect.proto.StreamingQueryListenerEventsResult +import org.apache.spark.internal.Logging +import org.apache.spark.sql.connect.service.ExecuteHolder + +/** + * Handle long-running streaming query listener events. + */ +class SparkConnectStreamingQueryListenerHandler(executeHolder: ExecuteHolder) extends Logging { + + val sessionHolder = executeHolder.sessionHolder + + private[connect] def userId: String = sessionHolder.userId + + private[connect] def sessionId: String = sessionHolder.sessionId + + /** + * The handler logic. The handler of ADD_LISTENER_BUS_LISTENER uses the + * streamingQueryListenerLatch to block the handling thread, preventing it from sending back the + * final ResultComplete response. 
+ * + * The handler of REMOVE_LISTENER_BUS_LISTENER cleans up the server side listener resources and + * count down the latch, allowing the handling thread of the original ADD_LISTENER_BUS_LISTENER + * to proceed to send back the final ResultComplete response. + */ + def handleListenerCommand( + command: StreamingQueryListenerBusCommand, + responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { + + val listenerHolder = sessionHolder.streamingServersideListenerHolder + + command.getCommandCase match { + case StreamingQueryListenerBusCommand.CommandCase.ADD_LISTENER_BUS_LISTENER => + listenerHolder.isServerSideListenerRegistered match { + case true => + logWarning( + s"[SessionId: $sessionId][UserId: $userId][operationId: " + + s"${executeHolder.operationId}] Redundant server side listener added. Exiting.") + return + case false => + // This transfers sending back the response to the client until + // the long running command is terminated, either by + // errors in streamingQueryServerSideListener.send, + // or client issues a REMOVE_LISTENER_BUS_LISTENER call. + listenerHolder.init(responseObserver) + // Send back listener added response + val respBuilder = StreamingQueryListenerEventsResult.newBuilder() + val listenerAddedResult = respBuilder + .setListenerBusListenerAdded(true) + .build() + try { + responseObserver.onNext( + ExecutePlanResponse + .newBuilder() + .setSessionId(sessionHolder.sessionId) + .setServerSideSessionId(sessionHolder.serverSessionId) + .setStreamingQueryListenerEventsResult(listenerAddedResult) + .build()) + } catch { + case NonFatal(e) => + logError( + s"[SessionId: $sessionId][UserId: $userId][operationId: " + + s"${executeHolder.operationId}] Error sending listener added response.", + e) + listenerHolder.cleanUp() + return + } + } + logInfo(s"[SessionId: $sessionId][UserId: $userId][operationId: " + + s"${executeHolder.operationId}] Server side listener added. 
Now blocking until " + + "all client side listeners are removed or there is error transmitting the event back.") + // Block the handling thread, and have serverListener continuously send back new events + listenerHolder.streamingQueryListenerLatch.await() + logInfo(s"[SessionId: $sessionId][UserId: $userId][operationId: " + + s"${executeHolder.operationId}] Server side listener long-running handling thread ended.") + case StreamingQueryListenerBusCommand.CommandCase.REMOVE_LISTENER_BUS_LISTENER => + listenerHolder.isServerSideListenerRegistered match { + case true => + listenerHolder.cleanUp() + case false => + logWarning( + s"[SessionId: $sessionId][UserId: $userId][operationId: " + + s"${executeHolder.operationId}] No active server side listener bus listener " + + s"but received remove listener call. Exiting.") + return + } + case StreamingQueryListenerBusCommand.CommandCase.COMMAND_NOT_SET => + throw new IllegalArgumentException("Missing command in StreamingQueryListenerBusCommand") + } + // If this thread is the handling thread of the original ADD_LISTENER_BUS_LISTENER command, + // the finish event below is posted once the latch is counted down (either through + // a REMOVE_LISTENER_BUS_LISTENER command, or when the long-lived gRPC stream throws). + // If this thread is the handling thread of the REMOVE_LISTENER_BUS_LISTENER command, + // this is hit right away.
+ executeHolder.eventsManager.postFinished() + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index ef79cdcce8ff..306b89148583 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -92,6 +92,8 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio private[connect] lazy val streamingForeachBatchRunnerCleanerCache = new StreamingForeachBatchHelper.CleanerCache(this) + private[connect] lazy val streamingServersideListenerHolder = new ServerSideListenerHolder(this) + def key: SessionKey = SessionKey(userId, sessionId) // Returns the server side session ID and asserts that it must be different from the client-side @@ -267,6 +269,11 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio streamingForeachBatchRunnerCleanerCache.cleanUpAll() // Clean up any streaming workers. removeAllListeners() // removes all listener and stop python listener processes if necessary. + // if there is a server side listener, clean up related resources + if (streamingServersideListenerHolder.isServerSideListenerRegistered) { + streamingServersideListenerHolder.cleanUp() + } + // Clean up all executions. // After closedTimeMs is defined, SessionHolder.addExecuteHolder() will not allow new executions // to be added for this session anymore. 
Because both SessionHolder.addExecuteHolder() and diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala new file mode 100644 index 000000000000..1b6c5179871d --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connect.service + +import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, CountDownLatch} + +import scala.jdk.CollectionConverters._ +import scala.util.control.NonFatal + +import io.grpc.stub.StreamObserver + +import org.apache.spark.connect.proto.ExecutePlanResponse +import org.apache.spark.connect.proto.StreamingQueryEventType +import org.apache.spark.connect.proto.StreamingQueryListenerEvent +import org.apache.spark.connect.proto.StreamingQueryListenerEventsResult +import org.apache.spark.internal.Logging +import org.apache.spark.sql.streaming.StreamingQueryListener +import org.apache.spark.util.ArrayImplicits._ + +/** + * A holder for the server side listener and related resources. There should be only one such + * holder for each sessionHolder. init() and cleanUp() are synchronized on this holder. + */ +private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) { + // The server side listener that is responsible to stream streaming query events back to client. + // There is only one listener per sessionHolder, but each listener is responsible for all events + // of all streaming queries in the SparkSession. + @volatile var streamingQueryServerSideListener: Option[SparkConnectListenerBusListener] = None + // The count down latch to hold the long-running listener thread before sending ResultComplete. + @volatile var streamingQueryListenerLatch = new CountDownLatch(1) + // The cache for QueryStartedEvent, key is query runId and value is the actual QueryStartedEvent. + // Events for corresponding query will be sent back to client with + // the WriteStreamOperationStart response, so that the client can handle the event before + // DataStreamWriter.start() returns. This special handling is to satisfy the contract of + // onQueryStarted in StreamingQueryListener.
+ val streamingQueryStartedEventCache + : ConcurrentMap[String, StreamingQueryListener.QueryStartedEvent] = new ConcurrentHashMap() + + def isServerSideListenerRegistered: Boolean = streamingQueryServerSideListener.isDefined + + /** + * The initialization of the server side listener and related resources. This method is called + * when the first ADD_LISTENER_BUS_LISTENER command is received. It is attached to a + * responseObserver, from the first executeThread (long running thread), so the lifecycle of the + * responseObserver is the same as the life cycle of the listener. + * + * @param responseObserver + * the responseObserver created from the first long running executeThread. + */ + def init(responseObserver: StreamObserver[ExecutePlanResponse]): Unit = synchronized { + streamingQueryListenerLatch = new CountDownLatch(1) + val serverListener = new SparkConnectListenerBusListener(this, responseObserver) + sessionHolder.session.streams.addListener(serverListener) + streamingQueryServerSideListener = Some(serverListener) + } + + /** + * The cleanup of the server side listener and related resources. This method is called when the + * REMOVE_LISTENER_BUS_LISTENER command is received or when responseObserver.onNext throws an + * exception. It removes the listener from the session, clears the cache. Also it counts down + * the latch, so the long-running thread can proceed to send back the final ResultComplete + * response. Synchronized with init(), since send() may call it from a listener callback. + */ + def cleanUp(): Unit = synchronized { + streamingQueryServerSideListener.foreach { listener => + sessionHolder.session.streams.removeListener(listener) + } + streamingQueryStartedEventCache.clear() + streamingQueryServerSideListener = None + streamingQueryListenerLatch.countDown() + } +} + +/** + * A customized StreamingQueryListener used in Spark Connect for the client-side listeners. Upon + * the invocation of each callback function, it serializes the event to json and sends it to the + * client.
+ */ +private[sql] class SparkConnectListenerBusListener( + serverSideListenerHolder: ServerSideListenerHolder, + responseObserver: StreamObserver[ExecutePlanResponse]) + extends StreamingQueryListener + with Logging { + + val sessionHolder = serverSideListenerHolder.sessionHolder + // The method used to stream back the events to the client. + // The event is serialized to json and sent to the client. + // The responseObserver belongs to the first executeThread (long running thread), + // which is still held by the streamingQueryListenerLatch. + // If any exception is thrown while transmitting back the event, the listener is removed, + // all related resources are cleaned up, and the long-running thread will proceed to send + // the final ResultComplete response. + private def send(eventJson: String, eventType: StreamingQueryEventType): Unit = { + val event = StreamingQueryListenerEvent + .newBuilder() + .setEventJson(eventJson) + .setEventType(eventType) + .build() + + val respBuilder = StreamingQueryListenerEventsResult.newBuilder() + val eventResult = respBuilder + .addAllEvents(Array[StreamingQueryListenerEvent](event).toImmutableArraySeq.asJava) + .build() + + try { + responseObserver.onNext( + ExecutePlanResponse + .newBuilder() + .setSessionId(sessionHolder.sessionId) + .setServerSideSessionId(sessionHolder.serverSessionId) + .setStreamingQueryListenerEventsResult(eventResult) + .build()) + } catch { + case NonFatal(e) => + logError( + s"[SessionId: ${sessionHolder.sessionId}][UserId: ${sessionHolder.userId}] " + + s"Removing SparkConnectListenerBusListener and terminating the long-running thread " + + s"because of exception: $e", e) + // This likely means that the client is not responsive even with retry, we should + // remove this listener and cleanup resources.
+ serverSideListenerHolder.cleanUp() + } + } + + // QueryStartedEvent is sent to client along with WriteStreamOperationStartResult + override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = { + serverSideListenerHolder.streamingQueryStartedEventCache.put(event.runId.toString, event) + } + + override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = { + send(event.json, StreamingQueryEventType.QUERY_PROGRESS_EVENT) + } + + override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = { + send(event.json, StreamingQueryEventType.QUERY_TERMINATED_EVENT) + } + + override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit = { + send(event.json, StreamingQueryEventType.QUERY_IDLE_EVENT) + } +} diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala new file mode 100644 index 000000000000..4c2962fda507 --- /dev/null +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.service + +import java.util.UUID + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration.DurationInt +import scala.jdk.CollectionConverters._ + +import io.grpc.stub.StreamObserver +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.scalatestplus.mockito.MockitoSugar + +import org.apache.spark.SparkFunSuite +import org.apache.spark.connect.proto.{Command, ExecutePlanResponse} +import org.apache.spark.sql.connect.planner.SparkConnectStreamingQueryListenerHandler +import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryListener} +import org.apache.spark.sql.streaming.Trigger.ProcessingTime +import org.apache.spark.sql.test.SharedSparkSession + +class SparkConnectListenerBusListenerSuite + extends SparkFunSuite + with SharedSparkSession + with MockitoSugar { + + override def afterEach(): Unit = { + try { + spark.streams.active.foreach(_.stop()) + spark.streams.listListeners().foreach(spark.streams.removeListener) + } finally { + super.afterEach() + } + } + + // A test listener that caches all events + private class CacheEventsStreamingQueryListener( + startEvents: ArrayBuffer[StreamingQueryListener.QueryStartedEvent], + otherEvents: ArrayBuffer[StreamingQueryListener.Event]) + extends StreamingQueryListener { + + override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = { + startEvents += event + } + + override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = { + otherEvents += event + } + + override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = { + otherEvents += event + } + + override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit = { + otherEvents += event + } + } + + private def 
verifyEventsSent( + fromCachedEventsListener: ArrayBuffer[StreamingQueryListener.Event], + fromListenerBusListener: ArrayBuffer[String]): Unit = { + assert(fromListenerBusListener.toSet === fromCachedEventsListener.map { + case e: StreamingQueryListener.QueryStartedEvent => e.json + case e: StreamingQueryListener.QueryProgressEvent => e.json + case e: StreamingQueryListener.QueryTerminatedEvent => e.json + case e: StreamingQueryListener.QueryIdleEvent => e.json + }.toSet) + } + + private def startQuery(slow: Boolean = false): StreamingQuery = { + val dsw = spark.readStream.format("rate").load().writeStream.format("noop") + if (slow) { + dsw.trigger(ProcessingTime("20 seconds")) + } + dsw.start() + } + + Seq(1, 5, 20).foreach { queryNum => + test( + "Basic functionalities - onQueryStart, onQueryProgress, onQueryTerminated" + + s" - $queryNum queries") { + val sessionHolder = SessionHolder.forTesting(spark) + val responseObserver = mock[StreamObserver[ExecutePlanResponse]] + val eventJsonBuffer = ArrayBuffer.empty[String] + val startEventsBuffer = ArrayBuffer.empty[StreamingQueryListener.QueryStartedEvent] + val otherEventsBuffer = ArrayBuffer.empty[StreamingQueryListener.Event] + + doAnswer((invocation: InvocationOnMock) => { + val argument = invocation.getArgument[ExecutePlanResponse](0) + val eventJson = argument.getStreamingQueryListenerEventsResult().getEvents(0).getEventJson + eventJsonBuffer += eventJson + }).when(responseObserver).onNext(any[ExecutePlanResponse]()) + + val listenerHolder = sessionHolder.streamingServersideListenerHolder + listenerHolder.init(responseObserver) + val cachedEventsListener = + new CacheEventsStreamingQueryListener(startEventsBuffer, otherEventsBuffer) + + spark.streams.addListener(cachedEventsListener) + + for (_ <- 1 to queryNum) startQuery() + + // after all queries made some progresses + eventually(timeout(60.seconds), interval(2.seconds)) { + spark.streams.active.foreach { q => + assert(q.lastProgress.batchId > 5) + } + } + + 
// stops all queries + spark.streams.active.foreach(_.stop()) + + eventually(timeout(60.seconds), interval(500.milliseconds)) { + assert(eventJsonBuffer.nonEmpty) + assert(!listenerHolder.streamingQueryStartedEventCache.isEmpty) + verifyEventsSent(otherEventsBuffer, eventJsonBuffer) + assert( + startEventsBuffer.map(_.json).toSet === + listenerHolder.streamingQueryStartedEventCache.asScala.map(_._2.json).toSet) + } + } + } + + test("Basic functionalities - Slow query") { + val sessionHolder = SessionHolder.forTesting(spark) + val responseObserver = mock[StreamObserver[ExecutePlanResponse]] + val eventJsonBuffer = ArrayBuffer.empty[String] + val startEventsBuffer = ArrayBuffer.empty[StreamingQueryListener.QueryStartedEvent] + val otherEventsBuffer = ArrayBuffer.empty[StreamingQueryListener.Event] + + doAnswer((invocation: InvocationOnMock) => { + val argument = invocation.getArgument[ExecutePlanResponse](0) + val eventJson = argument.getStreamingQueryListenerEventsResult().getEvents(0).getEventJson + eventJsonBuffer += eventJson + }).when(responseObserver).onNext(any[ExecutePlanResponse]()) + + val listenerHolder = sessionHolder.streamingServersideListenerHolder + listenerHolder.init(responseObserver) + + val cachedEventsListener = + new CacheEventsStreamingQueryListener(startEventsBuffer, otherEventsBuffer) + spark.streams.addListener(cachedEventsListener) + + // Slow query + val q = startQuery(true) + + // after the slow query made some progresses + eventually(timeout(100.seconds), interval(7.seconds)) { + assert(q.lastProgress.batchId > 2) + } + + q.stop() + + eventually(timeout(60.seconds), interval(1.second)) { + assert(eventJsonBuffer.nonEmpty) + assert(!listenerHolder.streamingQueryStartedEventCache.isEmpty) + verifyEventsSent(otherEventsBuffer, eventJsonBuffer) + assert( + startEventsBuffer.map(_.json).toSet === + listenerHolder.streamingQueryStartedEventCache.asScala.map(_._2.json).toSet) + } + } + + test("Proper handling on onNext throw - initial 
response") { + val sessionHolder = SessionHolder.forTesting(spark) + + val executeHolder = mock[ExecuteHolder] + when(executeHolder.sessionHolder).thenReturn(sessionHolder) + when(executeHolder.operationId).thenReturn("operationId") + + val responseObserver = mock[StreamObserver[ExecutePlanResponse]] + doThrow(new RuntimeException("I'm dead")) + .when(responseObserver) + .onNext(any[ExecutePlanResponse]()) + + val listenerCntBeforeThrow = spark.streams.listListeners().size + + val handler = new SparkConnectStreamingQueryListenerHandler(executeHolder) + val listenerBusCmdBuilder = Command.newBuilder().getStreamingQueryListenerBusCommandBuilder + val addListenerCommand = listenerBusCmdBuilder.setAddListenerBusListener(true).build() + handler.handleListenerCommand(addListenerCommand, responseObserver) + + val listenerHolder = sessionHolder.streamingServersideListenerHolder + eventually(timeout(5.seconds), interval(500.milliseconds)) { + assert( + sessionHolder.streamingServersideListenerHolder.streamingQueryServerSideListener.isEmpty) + assert(spark.streams.listListeners().size === listenerCntBeforeThrow) + assert(listenerHolder.streamingQueryStartedEventCache.isEmpty) + assert(listenerHolder.streamingQueryListenerLatch.getCount === 0) + } + + } + + test("Proper handling on onNext throw - query progress") { + val sessionHolder = SessionHolder.forTesting(spark) + val responseObserver = mock[StreamObserver[ExecutePlanResponse]] + doThrow(new RuntimeException("I'm dead")) + .when(responseObserver) + .onNext(any[ExecutePlanResponse]()) + + val listenerHolder = sessionHolder.streamingServersideListenerHolder + listenerHolder.init(responseObserver) + val listenerBusListener = listenerHolder.streamingQueryServerSideListener.get + + // mock a QueryStartedEvent for cleanup test + val queryStartedEvent = new StreamingQueryListener.QueryStartedEvent( + UUID.randomUUID, + UUID.randomUUID, + "name", + "timestamp") + listenerHolder.streamingQueryStartedEventCache.put( + 
queryStartedEvent.runId.toString, + queryStartedEvent) + + startQuery() + + eventually(timeout(5.seconds), interval(500.milliseconds)) { + assert(!spark.streams.listListeners().contains(listenerBusListener)) + assert(listenerHolder.streamingQueryStartedEventCache.isEmpty) + assert(listenerHolder.streamingQueryListenerLatch.getCount === 0) + } + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org