This is an automated email from the ASF dual-hosted git repository.

ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e934c43b7c58 [SPARK-54585][SS] Fix State Store rollback when thread is in interrupted state
e934c43b7c58 is described below

commit e934c43b7c58efdcd6e1df49cbb39803226fa17c
Author: Dylan Wong <[email protected]>
AuthorDate: Wed Dec 10 09:11:56 2025 -0800

    [SPARK-54585][SS] Fix State Store rollback when thread is in interrupted state
    
    ### What changes were proposed in this pull request?
    
    1. Modifies `ChecksumCancellableFSDataOutputStream.cancel()` to cancel both the main stream and checksum stream synchronously instead of using Futures with awaitResult.
    
    2. Moves `changelogWriter.foreach(_.abort())` and `changelogWriter = None` into a try-finally block within `RocksDB.rollback()`.
    
    ### Why are the changes needed?
    
    For fix 1:
    
    When cancel() is called while the thread is in an interrupted state (e.g., during task cancellation), the previous implementation would fail. The code submitted Futures to cancel each stream, then called awaitResult() to wait for completion. However, awaitResult() checks the thread's interrupt flag and throws InterruptedException immediately if the thread is interrupted.
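
    As a rough illustration (a minimal plain-Scala sketch using `scala.concurrent.Await` rather than Spark's `ThreadUtils.awaitResult`, with hypothetical names), waiting on a not-yet-completed Future from an already-interrupted thread fails immediately:

    ```scala
    import scala.concurrent.{Await, Future}
    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.concurrent.duration.Duration

    object InterruptedAwaitDemo {
      def main(args: Array[String]): Unit = {
        // A future that is still running when we start waiting on it.
        val slow = Future { Thread.sleep(1000); "streams cancelled" }

        // Simulate task cancellation: the calling thread's interrupt flag is set.
        Thread.currentThread().interrupt()

        try {
          // Blocks interruptibly, so it throws InterruptedException right away
          // instead of waiting for the future to finish.
          Await.result(slow, Duration.Inf)
        } catch {
          case _: InterruptedException =>
            println("wait aborted: calling thread was interrupted")
        }
      }
    }
    ```

    The patch instead cancels both streams on the calling thread, wrapped in `Utils.tryWithSafeFinally`, so the checksum stream's cancel still runs even if the main stream's cancel throws.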
    
    For fix 2:
    
    Consider the case where `abort()` is called on `RocksDBStateStoreProvider`. This calls `rollback()` on the `RocksDB` instance, which in turn calls `changelogWriter.foreach(_.abort())` and then sets `changelogWriter = None`.
    
    However, if `changelogWriter.abort()` throws an exception, the finally block inside `abort()` still sets `backingFileStream` and `compressedStream` to `null`. The exception then propagates, and we never reach the line that sets `changelogWriter = None`.
    
    This leaves the RocksDB instance in an inconsistent state:
    - changelogWriter = Some(changelogWriterWeAttemptedToAbort)
    - changelogWriterWeAttemptedToAbort.backingFileStream = null
    - changelogWriterWeAttemptedToAbort.compressedStream = null
    
    Now consider calling `RocksDB.load()` again. This calls `replayChangelog()`, which calls `put()`, which calls `changelogWriter.put()`. At this point, the assertion `assert(compressedStream != null)` fails, causing an exception while loading the StateStore.
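
    To make the failure mode concrete, here is a hypothetical, stripped-down sketch (ToyChangelogWriter and its field names are illustrative, not the actual Spark classes) of how a failed abort() leaves a stale writer whose nulled stream trips the assertion on the next put():

    ```scala
    import java.io.{ByteArrayOutputStream, IOException, OutputStream}

    // Hypothetical stand-in for the real changelog writer.
    class ToyChangelogWriter {
      private var compressedStream: OutputStream = new ByteArrayOutputStream()

      def abort(): Unit = {
        try {
          // Pretend cancelling the underlying stream fails (e.g. the thread was interrupted).
          throw new IOException("cancel failed")
        } finally {
          compressedStream = null  // streams are nulled even though the cancel failed
        }
      }

      def put(key: String, value: String): Unit = {
        assert(compressedStream != null)  // the precondition that ends up failing
        compressedStream.write(s"$key=$value".getBytes("UTF-8"))
      }
    }

    object StaleWriterDemo {
      def main(args: Array[String]): Unit = {
        var changelogWriter: Option[ToyChangelogWriter] = Some(new ToyChangelogWriter)

        // Old behaviour: abort() throws, so the next line never runs and the
        // stale writer stays referenced.
        try {
          changelogWriter.foreach(_.abort())
          changelogWriter = None
        } catch {
          case _: IOException => // rollback gives up here
        }

        // A later load()/replayChangelog() reuses the stale writer and hits the assertion.
        changelogWriter.foreach(_.put("key", "value"))  // AssertionError
      }
    }
    ```

    Moving `changelogWriter = None` into a finally block, as this patch does, drops the stale writer even when `abort()` throws, so the next load creates a fresh one.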
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added test `"SPARK-54585: Interrupted task calling rollback does not throw 
an exception"` which simulates the case when a thread in the interrupted state 
and begins a rollback
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #53313 from dylanwong250/SPARK-54585.
    
    Authored-by: Dylan Wong <[email protected]>
    Signed-off-by: Anish Shrigondekar <[email protected]>
---
 .../ChecksumCheckpointFileManager.scala            | 15 +++---
 .../sql/execution/streaming/state/RocksDB.scala    | 28 ++++++----
 .../RocksDBCheckpointFailureInjectionSuite.scala   | 60 ++++++++++++++++++++++
 3 files changed, 84 insertions(+), 19 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/ChecksumCheckpointFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/ChecksumCheckpointFileManager.scala
index 637d11ad890b..2de429ee1076 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/ChecksumCheckpointFileManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/ChecksumCheckpointFileManager.scala
@@ -39,6 +39,7 @@ import org.apache.spark.internal.LogKeys.{CHECKSUM, NUM_BYTES, PATH, TIMEOUT}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.streaming.checkpointing.CheckpointFileManager.CancellableFSDataOutputStream
 import org.apache.spark.util.ThreadUtils
+import org.apache.spark.util.Utils
 
 /** Information about the creator of the checksum file. Useful for debugging */
 case class ChecksumFileCreatorInfo(
@@ -500,16 +501,14 @@ class ChecksumCancellableFSDataOutputStream(
   @volatile private var closed = false
 
   override def cancel(): Unit = {
-    val mainFuture = Future {
+    // Cancel both streams synchronously rather than using futures. If the current thread is
+    // interrupted and we call this method, scheduling work on futures would immediately throw
+    // InterruptedException leaving the streams in an inconsistent state.
+    Utils.tryWithSafeFinally {
       mainStream.cancel()
-    }(uploadThreadPool)
-
-    val checksumFuture = Future {
+    } {
       checksumStream.cancel()
-    }(uploadThreadPool)
-
-    awaitResult(mainFuture, Duration.Inf)
-    awaitResult(checksumFuture, Duration.Inf)
+    }
   }
 
   override def close(): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index 8a2ed6d9a529..c92c5017cada 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -1650,17 +1650,23 @@ class RocksDB(
    * Drop uncommitted changes, and roll back to previous version.
    */
   def rollback(): Unit = {
-    numKeysOnWritingVersion = numKeysOnLoadedVersion
-    numInternalKeysOnWritingVersion = numInternalKeysOnLoadedVersion
-    loadedVersion = -1L
-    lastCommitBasedStateStoreCkptId = None
-    lastCommittedStateStoreCkptId = None
-    loadedStateStoreCkptId = None
-    sessionStateStoreCkptId = None
-    lineageManager.clear()
-    changelogWriter.foreach(_.abort())
-    // Make sure changelogWriter gets recreated next time.
-    changelogWriter = None
+    logInfo(
+      log"Rolling back uncommitted changes on version 
${MDC(LogKeys.VERSION_NUM, loadedVersion)}")
+    try {
+      numKeysOnWritingVersion = numKeysOnLoadedVersion
+      numInternalKeysOnWritingVersion = numInternalKeysOnLoadedVersion
+      loadedVersion = -1L
+      lastCommitBasedStateStoreCkptId = None
+      lastCommittedStateStoreCkptId = None
+      loadedStateStoreCkptId = None
+      sessionStateStoreCkptId = None
+      lineageManager.clear()
+      changelogWriter.foreach(_.abort())
+    } finally {
+      // Make sure changelogWriter gets recreated next time even if the changelogWriter aborts with
+      // an exception.
+      changelogWriter = None
+    }
     logInfo(log"Rolled back to ${MDC(LogKeys.VERSION_NUM, loadedVersion)}")
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala
index 0b9690ee7277..c48b492e27c2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala
@@ -70,6 +70,10 @@ class RocksDBCheckpointFailureInjectionSuite extends StreamTest
 
  implicit def toArray(str: String): Array[Byte] = if (str != null) str.getBytes else null
 
+  implicit def toStr(bytes: Array[Byte]): String = if (bytes != null) new String(bytes) else null
+
+  def toStr(kv: ByteArrayPair): (String, String) = (toStr(kv.key), toStr(kv.value))
+
  case class FailureConf(ifEnableStateStoreCheckpointIds: Boolean, fileType: String) {
     override def toString: String = {
       s"ifEnableStateStoreCheckpointIds = $ifEnableStateStoreCheckpointIds, " +
@@ -824,6 +828,62 @@ class RocksDBCheckpointFailureInjectionSuite extends StreamTest
     }
   }
 
+  /**
+   * Test that verifies that when a task is interrupted, the store's rollback() method does not
+   * throw an exception and the store can still be used after the rollback.
+   */
+  test("SPARK-54585: Interrupted task calling rollback does not throw an 
exception") {
+    val hadoopConf = new Configuration()
+    hadoopConf.set(
+      STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key,
+      fileManagerClassName
+    )
+    withTempDirAllowFailureInjection { (remoteDir, _) =>
+      val sqlConf = new SQLConf()
+      sqlConf.setConfString("spark.sql.streaming.checkpoint.fileChecksum.enabled", "true")
+      val rocksdbChangelogCheckpointingConfKey =
+        RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled"
+      sqlConf.setConfString(rocksdbChangelogCheckpointingConfKey, "true")
+      val conf = RocksDBConf(StateStoreConf(sqlConf))
+
+      withDB(
+        remoteDir.getAbsolutePath,
+        version = 0,
+        conf = conf,
+        hadoopConf = hadoopConf
+      ) { db =>
+        db.put("key0", "value0")
+        val checkpointId1 = commitAndGetCheckpointId(db)
+
+        db.load(1, checkpointId1)
+        db.put("key1", "value1")
+        val checkpointId2 = commitAndGetCheckpointId(db)
+
+        db.load(2, checkpointId2)
+        db.put("key2", "value2")
+
+        // Simulate what happens when a task is killed, the thread's interrupt flag is set.
+        // This replicates the scenario where TaskContext.markTaskFailed() is called and
+        // the task failure listener invokes RocksDBStateStore.abort() -> rollback().
+        Thread.currentThread().interrupt()
+
+        // rollback() should not throw an exception
+        db.rollback()
+
+        // Clear the interrupt flag for subsequent operations
+        Thread.interrupted()
+
+        // Reload the store and insert a new value
+        db.load(2, checkpointId2)
+        db.put("key3", "value3")
+
+        // Verify the store has the correct values
+        assert(db.iterator().map(toStr).toSet ===
+          Set(("key0", "value0"), ("key1", "value1"), ("key3", "value3")))
+      }
+    }
+  }
+
   def commitAndGetCheckpointId(db: RocksDB): Option[String] = {
     val (v, ci) = db.commit()
     ci.stateStoreCkptId


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
