rangadi commented on code in PR #41318:
URL: https://github.com/apache/spark/pull/41318#discussion_r1213488092


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala:
##########
@@ -2597,6 +2600,50 @@ class SparkConnectPlanner(val session: SparkSession) {
         .build())
   }
 
+  /**
+   * A helper function to handle streaming awaitTermination(). 
awaitTermination() can be a long
+   * running command. In this function, we periodically check if the RPC call 
has been cancelled.
+   * If so, we can stop the operation and release resources early.
+   * @param query the query waits to be terminated
+   * @param timeoutMs optional. Timeout to wait for termination. If None, no 
timeout is set
+   * @return if the query has terminated
+   */
+  private def handleStreamingAwaitTermination(
+      query: StreamingQuery,
+      timeoutMs: Option[Long]): Boolean = {
+    // How often to check if RPC is cancelled and call awaitTermination()
+    val awaitTerminationIntervalMs = 10000
+
+    val hasTimeout = timeoutMs.isDefined
+    var timeoutLeftMs = timeoutMs.getOrElse(Long.MaxValue)
+    require(timeoutLeftMs > 0, "Timeout has to be positive")
+
+    val grpcContext = Context.current
+    while (!grpcContext.isCancelled) {
+      val awaitTimeMs = if (hasTimeout) {
+        math.min(awaitTerminationIntervalMs, timeoutLeftMs)
+      } else {
+        awaitTerminationIntervalMs
+      }
+
+      val terminated = query.awaitTermination(awaitTimeMs)
+      if (terminated) {
+        return true
+      }
+
+      if (hasTimeout) {
+        timeoutLeftMs -= awaitTerminationIntervalMs
+        if (timeoutLeftMs <= 0) {
+          return false
+        }
+      }
+    }
+
+    // gRPC is cancelled
+    logError("RPC context is cancelled when executing awaitTermination()")

Review Comment:
   Should be warning at most. Info is fine too. Log the query id. 



##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala:
##########
@@ -2576,11 +2578,12 @@ class SparkConnectPlanner(val session: SparkSession) {
 
       case StreamingQueryCommand.CommandCase.AWAIT_TERMINATION =>
         if (command.getAwaitTermination.hasTimeoutMs) {
-          val terminated = 
query.awaitTermination(command.getAwaitTermination.getTimeoutMs)
+          val terminated = handleStreamingAwaitTermination(query,
+            Some(command.getAwaitTermination.getTimeoutMs))
           respBuilder.getAwaitTerminationBuilder
             .setTerminated(terminated)
         } else {
-          query.awaitTermination()
+          handleStreamingAwaitTermination(query, None)
           respBuilder.getAwaitTerminationBuilder
             .setTerminated(true)
         }

Review Comment:
   Can we simplify this? We don't need the larger conditional. It could just be:
   
   ```scala
   val timeout = if (command.getAwaitTermination.hasTimeoutMs) {
     Some(command.getAwaitTermination.getTimeoutMs)
   } else {
     None
   }
   val terminated = handleStreamingAwaitTermination(query, timeout)
   respBuilder.getAwaitTerminationBuilder
     .setTerminated(terminated)
   ```
   



##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala:
##########
@@ -2597,6 +2600,50 @@ class SparkConnectPlanner(val session: SparkSession) {
         .build())
   }
 
+  /**
+   * A helper function to handle streaming awaitTermination(). 
awaitTermination() can be a long
+   * running command. In this function, we periodically check if the RPC call 
has been cancelled.
+   * If so, we can stop the operation and release resources early.
+   * @param query the query waits to be terminated
+   * @param timeoutMs optional. Timeout to wait for termination. If None, no 
timeout is set
+   * @return if the query has terminated
+   */
+  private def handleStreamingAwaitTermination(
+      query: StreamingQuery,
+      timeoutMs: Option[Long]): Boolean = {
+    // How often to check if RPC is cancelled and call awaitTermination()
+    val awaitTerminationIntervalMs = 10000
+
+    val hasTimeout = timeoutMs.isDefined
+    var timeoutLeftMs = timeoutMs.getOrElse(Long.MaxValue)
+    require(timeoutLeftMs > 0, "Timeout has to be positive")
+
+    val grpcContext = Context.current
+    while (!grpcContext.isCancelled) {
+      val awaitTimeMs = if (hasTimeout) {

Review Comment:
   No need to check here; `timeoutLeftMs` is already set to `Long.MaxValue` when no timeout is given. 



##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala:
##########
@@ -2597,6 +2600,50 @@ class SparkConnectPlanner(val session: SparkSession) {
         .build())
   }
 
+  /**
+   * A helper function to handle streaming awaitTermination(). 
awaitTermination() can be a long
+   * running command. In this function, we periodically check if the RPC call 
has been cancelled.
+   * If so, we can stop the operation and release resources early.
+   * @param query the query waits to be terminated
+   * @param timeoutMs optional. Timeout to wait for termination. If None, no 
timeout is set
+   * @return if the query has terminated
+   */
+  private def handleStreamingAwaitTermination(
+      query: StreamingQuery,
+      timeoutMs: Option[Long]): Boolean = {
+    // How often to check if RPC is cancelled and call awaitTermination()
+    val awaitTerminationIntervalMs = 10000
+
+    val hasTimeout = timeoutMs.isDefined
+    var timeoutLeftMs = timeoutMs.getOrElse(Long.MaxValue)
+    require(timeoutLeftMs > 0, "Timeout has to be positive")
+
+    val grpcContext = Context.current
+    while (!grpcContext.isCancelled) {
+      val awaitTimeMs = if (hasTimeout) {
+        math.min(awaitTerminationIntervalMs, timeoutLeftMs)
+      } else {
+        awaitTerminationIntervalMs
+      }
+
+      val terminated = query.awaitTermination(awaitTimeMs)
+      if (terminated) {
+        return true
+      }
+
+      if (hasTimeout) {
+        timeoutLeftMs -= awaitTerminationIntervalMs

Review Comment:
   Subtract the actual time elapsed (measure around `awaitTermination()`), not the fixed interval. 



##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala:
##########
@@ -2597,6 +2600,50 @@ class SparkConnectPlanner(val session: SparkSession) {
         .build())
   }
 
+  /**
+   * A helper function to handle streaming awaitTermination(). 
awaitTermination() can be a long
+   * running command. In this function, we periodically check if the RPC call 
has been cancelled.
+   * If so, we can stop the operation and release resources early.
+   * @param query the query waits to be terminated
+   * @param timeoutMs optional. Timeout to wait for termination. If None, no 
timeout is set
+   * @return if the query has terminated
+   */
+  private def handleStreamingAwaitTermination(
+      query: StreamingQuery,
+      timeoutMs: Option[Long]): Boolean = {
+    // How often to check if RPC is cancelled and call awaitTermination()
+    val awaitTerminationIntervalMs = 10000
+
+    val hasTimeout = timeoutMs.isDefined

Review Comment:
   This is not needed. 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to