rangadi commented on code in PR #42460:
URL: https://github.com/apache/spark/pull/42460#discussion_r1293963420


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala:
##########
@@ -2850,16 +2857,26 @@ class SparkConnectPlanner(val sessionHolder: 
SessionHolder) extends Logging {
       writer.foreachBatch(foreachBatchFn)
     }
 
-    val query = writeOp.getPath match {
-      case "" if writeOp.hasTableName => writer.toTable(writeOp.getTableName)
-      case "" => writer.start()
-      case path => writer.start(path)
-    }
+    val query =
+      try {
+        writeOp.getPath match {
+          case "" if writeOp.hasTableName => 
writer.toTable(writeOp.getTableName)
+          case "" => writer.start()
+          case path => writer.start(path)
+        }
+      } catch {
+        case NonFatal(ex) => // Failed to start the query, clean up foreach 
runner if any.

Review Comment:
   That is the usual best practice. The expectation is that any other error is a 
   fatal error and should not be caught. 
   One common error that is not matched by `NonFatal` is `OutOfMemoryError`.



##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala:
##########
@@ -2850,16 +2857,26 @@ class SparkConnectPlanner(val sessionHolder: 
SessionHolder) extends Logging {
       writer.foreachBatch(foreachBatchFn)
     }
 
-    val query = writeOp.getPath match {
-      case "" if writeOp.hasTableName => writer.toTable(writeOp.getTableName)
-      case "" => writer.start()
-      case path => writer.start(path)
-    }
+    val query =
+      try {
+        writeOp.getPath match {
+          case "" if writeOp.hasTableName => 
writer.toTable(writeOp.getTableName)
+          case "" => writer.start()
+          case path => writer.start(path)
+        }
+      } catch {
+        case NonFatal(ex) => // Failed to start the query, clean up foreach 
runner if any.

Review Comment:
   See the Spark style guide: 
https://github.com/databricks/scala-style-guide#exception-handling-try-vs-try



##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala:
##########
@@ -31,6 +38,17 @@ object StreamingForeachBatchHelper extends Logging {
 
   type ForeachBatchFnType = (DataFrame, Long) => Unit
 
+  case class RunnerCleaner(runner: StreamingPythonRunner) extends 
AutoCloseable {
+    override def close(): Unit = {
+      try runner.stop()
+      catch {
+        case NonFatal(ex) =>
+          logWarning("Error while stopping streaming Python worker", ex)
+        // Exception is not propagated.

Review Comment:
   The indentation is due to scalafmt. I will move it up so that it is aligned 
better. 
   
   Matching on `NonFatal` is good practice: it catches only application errors 
   rather than JVM-related fatal errors. This is especially important when 
   swallowing exceptions. See the Spark code style guide: 
https://github.com/databricks/scala-style-guide#exception-handling-try-vs-try



##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala:
##########
@@ -2850,16 +2857,26 @@ class SparkConnectPlanner(val sessionHolder: 
SessionHolder) extends Logging {
       writer.foreachBatch(foreachBatchFn)
     }
 
-    val query = writeOp.getPath match {
-      case "" if writeOp.hasTableName => writer.toTable(writeOp.getTableName)
-      case "" => writer.start()
-      case path => writer.start(path)
-    }
+    val query =
+      try {
+        writeOp.getPath match {
+          case "" if writeOp.hasTableName => 
writer.toTable(writeOp.getTableName)
+          case "" => writer.start()
+          case path => writer.start(path)
+        }
+      } catch {
+        case NonFatal(ex) => // Failed to start the query, clean up foreach 
runner if any.
+          foreachBatchRunnerCleaner.foreach(_.close())
+          throw ex
+      }
 
     // Register the new query so that the session and query references are 
cached.
-    SparkConnectService.streamingSessionManager.registerNewStreamingQuery(
-      sessionHolder = SessionHolder(userId = userId, sessionId = sessionId, 
session),
-      query = query)
+    
SparkConnectService.streamingSessionManager.registerNewStreamingQuery(sessionHolder,
 query)
+    // Register the runner with the query if Python foreachBatch is enabled.
+    foreachBatchRunnerCleaner.foreach { cleaner =>
+      sessionHolder.streamingRunnerCleanerCache.registerCleanerForQuery(query, 
cleaner)
+    }
+    // Register the new query so that the session and query references are 
cached.

Review Comment:
   Removed. Also updated the original comment.



##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala:
##########
@@ -180,13 +185,16 @@ private[connect] class SparkConnectStreamingQueryCache(
           case Some(_) => // Inactive query waiting for expiration. Do nothing.
             logInfo(s"Waiting for the expiration for $id in session 
${v.sessionId}")
 
-          case None => // Active query, check if it is stopped. Keep the 
session alive.
+          case None => // Active query, check if it is stopped. Enable timeout 
if it is stopped.
             val isActive = v.query.isActive && 
Option(v.session.streams.get(id)).nonEmpty
 
             if (!isActive) {
               logInfo(s"Marking query $id in session ${v.sessionId} inactive.")
               val expiresAtMs = nowMs + stoppedQueryInactivityTimeout.toMillis
               queryCache.put(k, v.copy(expiresAtMs = Some(expiresAtMs)))
+              // To consider: Clean up any runner registered for this query 
with the session holder

Review Comment:
   Shall I do this in this PR? Testing might be tricky. I updated the comment to 
   clarify that these would be cleaned up when the session expires.



##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala:
##########
@@ -153,7 +158,7 @@ case class SessionHolder(userId: String, sessionId: String, 
session: SparkSessio
     logDebug(s"Expiring session with userId: $userId and sessionId: 
$sessionId")
     artifactManager.cleanUpResources()
     eventManager.postClosed()
-
+    streamingRunnerCleanerCache.cleanUpAll()

Review Comment:
   Updated. I was thinking about which one should go first. Stopping the runners 
   after stopping the queries is better.
   The streaming listener event might fire after the worker is stopped, but that 
   is OK. 



##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala:
##########
@@ -111,17 +128,79 @@ object StreamingForeachBatchHelper extends Logging {
       logInfo(s"Python foreach batch for dfId ${args.dfId} completed (ret: 
$ret)")
     }
 
-    dataFrameCachingWrapper(foreachBatchRunnerFn, sessionHolder)
+    (dataFrameCachingWrapper(foreachBatchRunnerFn, sessionHolder), 
RunnerCleaner(runner))
   }
 
-  // TODO(SPARK-44433): Improve termination of Processes
-  //   The goal is that when a query is terminated, the python process 
associated with foreachBatch
-  //   should be terminated. One way to do that is by registering streaming 
query listener:
-  //   After pythonForeachBatchWrapper() is invoked by the SparkConnectPlanner.
-  //   At that time, we don't have the streaming queries yet.
-  //   Planner should call back into this helper with the query id when it 
starts it immediately
-  //   after. Save the query id to StreamingPythonRunner mapping. This mapping 
should be
-  //   part of the SessionHolder.
-  //   When a query is terminated, check the mapping and terminate any 
associated runner.
-  //   These runners should be terminated when a session is deleted (due to 
timeout, etc).
+  /**
+   * This manages cache from queries to cleaner for runners used for streaming 
queries. This is
+   * used in [[SessionHolder]].
+   */
+  class CleanerCache(sessionHolder: SessionHolder) {
+
+    private case class CacheKey(queryId: String, runId: String)
+
+    // Mapping from streaming (queryId, runId) to runner cleaner. Used for 
Python foreachBatch.
+    private val cleanerCache: ConcurrentMap[CacheKey, AutoCloseable] = new 
ConcurrentHashMap()
+
+    private lazy val streamingListener = { // Initialized on first registered 
query
+      val listener = new StreamingRunnerCleanerListener
+      sessionHolder.session.streams.addListener(listener)
+      logInfo(s"Registered runner clean up listener for session 
${sessionHolder.sessionId}")
+      listener
+    }
+
+    private[connect] def registerCleanerForQuery(
+        query: StreamingQuery,
+        cleaner: AutoCloseable): Unit = {
+
+      streamingListener // Access to initialize
+      val key = CacheKey(query.id.toString, query.runId.toString)
+
+      Option(cleanerCache.putIfAbsent(key, cleaner)) match {
+        case Some(_) =>
+          throw new IllegalStateException(s"Unexpected: a cleaner for query 
$key is already set")

Review Comment:
   This is an internal check. It should never happen (unlike 'SESSION_NOT_FOUND', 
   which can be triggered by a user sending a stale session id).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to