WweiL commented on code in PR #42460:
URL: https://github.com/apache/spark/pull/42460#discussion_r1293777160


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala:
##########
@@ -2850,16 +2857,26 @@ class SparkConnectPlanner(val sessionHolder: 
SessionHolder) extends Logging {
       writer.foreachBatch(foreachBatchFn)
     }
 
-    val query = writeOp.getPath match {
-      case "" if writeOp.hasTableName => writer.toTable(writeOp.getTableName)
-      case "" => writer.start()
-      case path => writer.start(path)
-    }
+    val query =
+      try {
+        writeOp.getPath match {
+          case "" if writeOp.hasTableName => 
writer.toTable(writeOp.getTableName)
+          case "" => writer.start()
+          case path => writer.start(path)
+        }
+      } catch {
+        case NonFatal(ex) => // Failed to start the query, clean up foreach 
runner if any.

Review Comment:
   Can you explain why we only catch `NonFatal` here?



##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala:
##########
@@ -2850,16 +2857,26 @@ class SparkConnectPlanner(val sessionHolder: 
SessionHolder) extends Logging {
       writer.foreachBatch(foreachBatchFn)
     }
 
-    val query = writeOp.getPath match {
-      case "" if writeOp.hasTableName => writer.toTable(writeOp.getTableName)
-      case "" => writer.start()
-      case path => writer.start(path)
-    }
+    val query =
+      try {
+        writeOp.getPath match {
+          case "" if writeOp.hasTableName => 
writer.toTable(writeOp.getTableName)
+          case "" => writer.start()
+          case path => writer.start(path)
+        }
+      } catch {
+        case NonFatal(ex) => // Failed to start the query, clean up foreach 
runner if any.
+          foreachBatchRunnerCleaner.foreach(_.close())
+          throw ex
+      }
 
     // Register the new query so that the session and query references are 
cached.
-    SparkConnectService.streamingSessionManager.registerNewStreamingQuery(
-      sessionHolder = SessionHolder(userId = userId, sessionId = sessionId, 
session),
-      query = query)
+    
SparkConnectService.streamingSessionManager.registerNewStreamingQuery(sessionHolder,
 query)
+    // Register the runner with the query if Python foreachBatch is enabled.
+    foreachBatchRunnerCleaner.foreach { cleaner =>
+      sessionHolder.streamingRunnerCleanerCache.registerCleanerForQuery(query, 
cleaner)
+    }
+    // Register the new query so that the session and query references are 
cached.

Review Comment:
   I think this comment needs to be deleted



##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala:
##########
@@ -153,7 +158,7 @@ case class SessionHolder(userId: String, sessionId: String, 
session: SparkSessio
     logDebug(s"Expiring session with userId: $userId and sessionId: 
$sessionId")
     artifactManager.cleanUpResources()
     eventManager.postClosed()
-
+    streamingRunnerCleanerCache.cleanUpAll()

Review Comment:
   I think it's better to put this below 
   ```
   SparkConnectService.streamingSessionManager.cleanupRunningQueries(this)
   ```
   so we first stop the queries, then stop the python processes.
   
   That way, if there are still queries running, it would prevent a lot of 
errors from being thrown. I think it's better even if the errors are caught anyway.



##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala:
##########
@@ -31,6 +38,17 @@ object StreamingForeachBatchHelper extends Logging {
 
   type ForeachBatchFnType = (DataFrame, Long) => Unit
 
+  case class RunnerCleaner(runner: StreamingPythonRunner) extends 
AutoCloseable {
+    override def close(): Unit = {
+      try runner.stop()
+      catch {
+        case NonFatal(ex) =>
+          logWarning("Error while stopping streaming Python worker", ex)
+        // Exception is not propagated.

Review Comment:
   indentation?
   
   I'm not sure what "exception is not propagated" means, but [PythonRunner 
also just logs 
this](https://github.com/apache/spark/blob/bf1dc9fd650747baa6abf16b5f7c13652362556d/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala#L214-L221)
   
   Also, may I know if there is any specific reason we only catch `NonFatal` 
errors?



##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala:
##########
@@ -180,13 +185,16 @@ private[connect] class SparkConnectStreamingQueryCache(
           case Some(_) => // Inactive query waiting for expiration. Do nothing.
             logInfo(s"Waiting for the expiration for $id in session 
${v.sessionId}")
 
-          case None => // Active query, check if it is stopped. Keep the 
session alive.
+          case None => // Active query, check if it is stopped. Enable timeout 
if it is stopped.
             val isActive = v.query.isActive && 
Option(v.session.streams.get(id)).nonEmpty
 
             if (!isActive) {
               logInfo(s"Marking query $id in session ${v.sessionId} inactive.")
               val expiresAtMs = nowMs + stoppedQueryInactivityTimeout.toMillis
               queryCache.put(k, v.copy(expiresAtMs = Some(expiresAtMs)))
+              // To consider: Clean up any runner registered for this query 
with the session holder

Review Comment:
   Should we file a ticket for this?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to