liviazhu commented on code in PR #53101:
URL: https://github.com/apache/spark/pull/53101#discussion_r2535713524


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala:
##########
@@ -2563,6 +2563,82 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     }
   }
 
+  test("SPARK-54389: State store loading fails when TaskContext has failed") {
+    // Timeline of this test (* means thread is active):
+    // STATE | MAIN THREAD            | STATE STORE THREAD          |
+    // ------| ---------------------- | --------------------------- |
+    // 1.    | wait for s2            | *set task context           |
+    //       |                        | *signal s2                  |
+    // ------| ---------------------- | --------------------------- |
+    // 2.    | *mark task failed      | wait for s3                 |
+    //       | *signal s3             |                             |
+    // ------| ---------------------- | --------------------------- |
+    // 3.    | wait thread return     | *call getStore().iterator() |
+    //       |                        | *exception thrown           |
+    // ------| ---------------------- | --------------------------- |
+    // 4.    | *verify exception, END |                             |
+
+    // Create a custom ExecutionContext with 2 threads
+    implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(
+      ThreadUtils.newDaemonFixedThreadPool(2, "pool-thread-executor"))
+    val stateLock = new Object()
+    var state = 1
+    val timeout = 10.seconds
+
+    tryWithProviderResource(newStoreProvider()) { provider =>
+      val taskContext = TaskContext.empty()
+
+      val stateStoreFuture = Future { // STATE STORE THREAD
+        stateLock.synchronized {
+          // -------------------- STATE 1 --------------------
+          // Set the task context for this thread
+          TaskContext.setTaskContext(taskContext)
+
+          // Signal that we have entered state 2
+          state = 2
+          stateLock.notifyAll()
+
+          // -------------------- STATE 3 --------------------
+          // Wait until we have entered state 3 (main thread marked task as failed)
+          while (state != 3) {
+            stateLock.wait()
+          }
+
+          // Try to call getStore().iterator() which should trigger the error handling
+          provider.getStore(0).iterator()
+        }
+      }
+
+      val ex = new IllegalStateException("failure")
+      stateLock.synchronized {
+        // -------------------- STATE 2 --------------------
+        // Wait until we have entered state 2 (state store thread set task context)
+        while (state != 2) {

Review Comment:
   Normally the loop body runs at most once: we `wait` until `notifyAll` is
   invoked, and the `while` is only there to guard against spurious wakeups.
   This is functionally equivalent to `CountDownLatch`, except that with
   `CountDownLatch` we'd need one latch per state transition, I believe.
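
To make the comparison concrete, here is a minimal standalone sketch of the same two-thread handoff written both ways. It is not the test code from this PR: `StateHandoffSketch`, `runWithMonitor`, `runWithLatches`, and the event strings are hypothetical names chosen for illustration, and the Spark-specific steps (setting the task context, marking the task failed, calling `getStore`) are replaced by recording an event.

```scala
import java.util.concurrent.CountDownLatch
import scala.collection.mutable.ArrayBuffer

// Hypothetical sketch, not the PR's test: the same handoff done two ways.
object StateHandoffSketch {

  // Variant A: one lock plus an integer state, as in the test under review.
  def runWithMonitor(): Seq[String] = {
    val lock = new Object()
    var state = 1
    val events = ArrayBuffer.empty[String]

    val worker = new Thread(() => {
      lock.synchronized {
        events += "worker: state 1"
        state = 2
        lock.notifyAll()
        // `while`, not `if`: wait() may return spuriously.
        while (state != 3) lock.wait()
        events += "worker: state 3"
      }
    })
    worker.start()

    lock.synchronized {
      // Skipped entirely if the worker already moved to state 2.
      while (state != 2) lock.wait()
      events += "main: state 2"
      state = 3
      lock.notifyAll()
    }
    worker.join()
    events.toSeq
  }

  // Variant B: one CountDownLatch per transition (two latches here).
  def runWithLatches(): Seq[String] = {
    val enteredState2 = new CountDownLatch(1)
    val enteredState3 = new CountDownLatch(1)
    // The buffer is never touched concurrently: each countDown()/await()
    // pair and the final join() order the accesses.
    val events = ArrayBuffer.empty[String]

    val worker = new Thread(() => {
      events += "worker: state 1"
      enteredState2.countDown()
      enteredState3.await()
      events += "worker: state 3"
    })
    worker.start()

    enteredState2.await()
    events += "main: state 2"
    enteredState3.countDown()
    worker.join()
    events.toSeq
  }

  def main(args: Array[String]): Unit = {
    println(runWithMonitor().mkString(", "))
    println(runWithLatches().mkString(", "))
  }
}
```

Both variants record the events in the same order. The monitor version reuses a single lock and counter for any number of states, while the latch version needs a new latch for each additional transition, which is the trade-off described above.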



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

