juliuszsompolski commented on code in PR #48208:
URL: https://github.com/apache/spark/pull/48208#discussion_r1773106069


##########
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala:
##########
@@ -139,23 +129,50 @@ private[connect] class ExecuteThreadRunner(executeHolder: 
ExecuteHolder) extends
         }
       }
     } catch {
-      ErrorUtils.handleError(
-        "execute",
-        executeHolder.responseObserver,
-        executeHolder.sessionHolder.userId,
-        executeHolder.sessionHolder.sessionId,
-        Some(executeHolder.eventsManager),
-        interrupted)
+      case e: Throwable if state.getAcquire() != 
ThreadState.startedInterrupted =>
+        ErrorUtils.handleError(
+          "execute",
+          executeHolder.responseObserver,
+          executeHolder.sessionHolder.userId,
+          executeHolder.sessionHolder.sessionId,
+          Some(executeHolder.eventsManager),
+          false)(e)
+    } finally {
+      // Make sure to transition to completed in order to prevent the thread 
from being interrupted
+      // afterwards.
+      var currentState = state.getAcquire()
+      while (currentState == ThreadState.started ||
+        currentState == ThreadState.startedInterrupted) {
+        val interrupted = currentState == ThreadState.startedInterrupted
+        val prevState = state.compareAndExchangeRelease(currentState, 
ThreadState.completed)
+        if (prevState == currentState) {
+          if (interrupted) {
+            try {
+              ErrorUtils.handleError(
+                "execute",
+                executeHolder.responseObserver,
+                executeHolder.sessionHolder.userId,
+                executeHolder.sessionHolder.sessionId,
+                Some(executeHolder.eventsManager),
+                true)(new SparkSQLException("OPERATION_CANCELED", Map.empty))
+            } finally {
+              executeHolder.cleanup()
+            }
+          }
+          return
+        }
+        currentState = prevState
+      }
     }
   }
 
   // Inner executeInternal is wrapped by execute() for error handling.
-  private def executeInternal() = {
-    // synchronized - check if already got interrupted while starting.
-    lock.synchronized {
-      if (interrupted) {
-        throw new InterruptedException()
-      }
+  private def executeInternal(): Unit = {
+    val prevState = state.compareAndExchangeRelease(ThreadState.notStarted, 
ThreadState.started)
+    if (prevState != ThreadState.notStarted && prevState != 
ThreadState.started) {
+      // Silently return, expecting that the caller would handle the 
interruption.
+      assert(prevState != ThreadState.completed)

Review Comment:
   Shouldn't `assert(prevState == ThreadState.interrupted)` hold here? The only way
   to reach this branch seems to be that the execution got interrupted right as it
   was starting, so `prevState` should not be `startedInterrupted` here, should it?
   Maybe extend this comment with more explanation?



##########
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala:
##########
@@ -32,76 +32,73 @@ import org.apache.spark.sql.connect.common.ProtoUtils
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
 import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag, 
SparkConnectService}
 import org.apache.spark.sql.connect.utils.ErrorUtils
-import org.apache.spark.util.{ThreadUtils, Utils}
+import org.apache.spark.util.Utils
 
 /**
  * This class launches the actual execution in an execution thread. The 
execution pushes the
  * responses to a ExecuteResponseObserver in executeHolder.
  */
 private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) 
extends Logging {
 
-  private val promise: Promise[Unit] = Promise[Unit]()
+  /** The thread state. */
+  private val state: AtomicInteger = new AtomicInteger(ThreadState.notStarted)
 
   // The newly created thread will inherit all InheritableThreadLocals used by 
Spark,
   // e.g. SparkContext.localProperties. If considering implementing a 
thread-pool,
   // forwarding of thread locals needs to be taken into account.
-  private val executionThread: ExecutionThread = new ExecutionThread(promise)
-
-  private var started: Boolean = false
-
-  private var interrupted: Boolean = false
-
-  private var completed: Boolean = false
-
-  private val lock = new Object
+  private val executionThread: ExecutionThread = new ExecutionThread()
 
   /** Launches the execution in a background thread, returns immediately. */
   private[connect] def start(): Unit = {
-    lock.synchronized {
-      assert(!started)
-      // Do not start if already interrupted.
-      if (!interrupted) {
-        executionThread.start()
-        started = true
-      }
-    }
-  }
+    if (state.getAcquire() == ThreadState.notStarted) {
+      executionThread.start()
 
-  /**
-   * Register a callback that gets executed after completion/interruption of 
the execution thread.
-   */
-  private[connect] def processOnCompletion(callback: Try[Unit] => Unit): Unit 
= {
-    
promise.future.onComplete(callback)(ExecuteThreadRunner.namedExecutionContext)
+      // If the thread is started earlier than this, the thread will change 
the state itself.
+      state.compareAndExchangeRelease(ThreadState.notStarted, 
ThreadState.started)
+    }
   }
 
   /**
-   * Interrupt the executing thread.
+   * Interrupts the execution thread if the thread is running and has yet to 
be completed.
+   *
    * @return
-   *   true if it was not interrupted before, false if it was already 
interrupted or completed.
+   *   true if the thread is running and interrupted.
    */
   private[connect] def interrupt(): Boolean = {
-    lock.synchronized {
-      if (!started && !interrupted) {
-        // execution thread hasn't started yet, and will not be started.
-        // handle the interrupted error here directly.
-        interrupted = true
-        ErrorUtils.handleError(
-          "execute",
-          executeHolder.responseObserver,
-          executeHolder.sessionHolder.userId,
-          executeHolder.sessionHolder.sessionId,
-          Some(executeHolder.eventsManager),
-          interrupted)(new SparkSQLException("OPERATION_CANCELED", Map.empty))
-        true
-      } else if (!interrupted && !completed) {
-        // checking completed prevents sending interrupt onError after 
onCompleted
-        interrupted = true
-        executionThread.interrupt()
-        true
+    var currentState = state.getAcquire()
+    while (currentState == ThreadState.notStarted || currentState == 
ThreadState.started) {
+      val newState = if (currentState == ThreadState.notStarted) {
+        ThreadState.interrupted
       } else {
-        false
+        ThreadState.startedInterrupted
       }
+
+      val prevState = state.compareAndExchangeRelease(currentState, newState)
+      if (prevState == currentState) {
+        if (prevState == ThreadState.notStarted) {
+          // The execution thread has not been started, and will never be 
started.

Review Comment:
   If this happens before the thread reaches the `notStarted -> started`
   transition in `executeInternal`, then `executeInternal` will silently exit?
   So it's technically not true that the thread has not been started: it might
   have been started but not yet reached that transition.



##########
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala:
##########
@@ -32,76 +32,73 @@ import org.apache.spark.sql.connect.common.ProtoUtils
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
 import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag, 
SparkConnectService}
 import org.apache.spark.sql.connect.utils.ErrorUtils
-import org.apache.spark.util.{ThreadUtils, Utils}
+import org.apache.spark.util.Utils
 
 /**
  * This class launches the actual execution in an execution thread. The 
execution pushes the
  * responses to a ExecuteResponseObserver in executeHolder.
  */
 private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) 
extends Logging {
 
-  private val promise: Promise[Unit] = Promise[Unit]()
+  /** The thread state. */
+  private val state: AtomicInteger = new AtomicInteger(ThreadState.notStarted)
 
   // The newly created thread will inherit all InheritableThreadLocals used by 
Spark,
   // e.g. SparkContext.localProperties. If considering implementing a 
thread-pool,
   // forwarding of thread locals needs to be taken into account.
-  private val executionThread: ExecutionThread = new ExecutionThread(promise)
-
-  private var started: Boolean = false
-
-  private var interrupted: Boolean = false
-
-  private var completed: Boolean = false
-
-  private val lock = new Object
+  private val executionThread: ExecutionThread = new ExecutionThread()
 
   /** Launches the execution in a background thread, returns immediately. */
   private[connect] def start(): Unit = {
-    lock.synchronized {
-      assert(!started)
-      // Do not start if already interrupted.
-      if (!interrupted) {
-        executionThread.start()
-        started = true
-      }
-    }
-  }
+    if (state.getAcquire() == ThreadState.notStarted) {
+      executionThread.start()
 
-  /**
-   * Register a callback that gets executed after completion/interruption of 
the execution thread.
-   */
-  private[connect] def processOnCompletion(callback: Try[Unit] => Unit): Unit 
= {
-    
promise.future.onComplete(callback)(ExecuteThreadRunner.namedExecutionContext)
+      // If the thread is started earlier than this, the thread will change 
the state itself.
+      state.compareAndExchangeRelease(ThreadState.notStarted, 
ThreadState.started)

Review Comment:
   Is this CAS needed here, or is the transition to `started` at the beginning
   of `executeInternal` enough?



##########
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala:
##########
@@ -32,76 +32,73 @@ import org.apache.spark.sql.connect.common.ProtoUtils
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
 import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag, 
SparkConnectService}
 import org.apache.spark.sql.connect.utils.ErrorUtils
-import org.apache.spark.util.{ThreadUtils, Utils}
+import org.apache.spark.util.Utils
 
 /**
  * This class launches the actual execution in an execution thread. The 
execution pushes the
  * responses to a ExecuteResponseObserver in executeHolder.
  */
 private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) 
extends Logging {
 
-  private val promise: Promise[Unit] = Promise[Unit]()
+  /** The thread state. */
+  private val state: AtomicInteger = new AtomicInteger(ThreadState.notStarted)
 
   // The newly created thread will inherit all InheritableThreadLocals used by 
Spark,
   // e.g. SparkContext.localProperties. If considering implementing a 
thread-pool,
   // forwarding of thread locals needs to be taken into account.
-  private val executionThread: ExecutionThread = new ExecutionThread(promise)
-
-  private var started: Boolean = false
-
-  private var interrupted: Boolean = false
-
-  private var completed: Boolean = false
-
-  private val lock = new Object
+  private val executionThread: ExecutionThread = new ExecutionThread()
 
   /** Launches the execution in a background thread, returns immediately. */
   private[connect] def start(): Unit = {
-    lock.synchronized {
-      assert(!started)
-      // Do not start if already interrupted.
-      if (!interrupted) {
-        executionThread.start()
-        started = true
-      }
-    }
-  }
+    if (state.getAcquire() == ThreadState.notStarted) {
+      executionThread.start()

Review Comment:
   The other possible state here is `interrupted`, if it got interrupted before
   being started?



##########
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala:
##########
@@ -226,17 +243,13 @@ private[connect] class ExecuteThreadRunner(executeHolder: 
ExecuteHolder) extends
               observedMetrics ++ accumulatedInPython))
       }
 
-      lock.synchronized {
-        // Synchronized before sending ResultComplete, and up until completing 
the result stream
-        // to prevent a situation in which a client of reattachable execution 
receives
-        // ResultComplete, and proceeds to send ReleaseExecute, and that 
triggers an interrupt
-        // before it finishes.
-
-        if (interrupted) {
-          // check if it got interrupted at the very last moment
-          throw new InterruptedException()
-        }
-        completed = true // no longer interruptible
+      // State transition should be atomic to prevent a situation in which a 
client of reattachable
+      // execution receives ResultComplete, and proceeds to send 
ReleaseExecute, and that triggers
+      // an interrupt before it finishes.
+      if (state.compareAndExchangeRelease(
+          ThreadState.started,
+          ThreadState.completed) == ThreadState.started) {
+        // Now, the execution cannot be interrupted.

Review Comment:
   Why don't we need to throw `InterruptedException` in an `else` branch of this
   `if` anymore?
   I guess it's because the interruption is now handled in the `finally` block
   rather than in the `catch`?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to