This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 934134e99aed [SPARK-50163][SS] Fix the RocksDB extra acquireLock 
release due to the completion listener
934134e99aed is described below

commit 934134e99aeda36f7795c46e73ab6a017d3113ad
Author: Livia Zhu <[email protected]>
AuthorDate: Sun Nov 3 07:40:03 2024 +0900

    [SPARK-50163][SS] Fix the RocksDB extra acquireLock release due to the 
completion listener
    
    ### What changes were proposed in this pull request?
    
    Adds a new optional `releaseForThreadOpt` parameter to `release(opType)` that only releases 
the RocksDB acquireLock if the provided thread info matches the thread info in 
`acquiredThreadInfo`. Also, modify the task completion listener created in 
`acquire` to pass this parameter so that it doesn't accidentally release the 
lock if it's owned by another thread.
    
    Also some small changes to usages of acquire() and release() in `rollback` 
and `close` to follow this pattern to avoid extraneous or missed releases due 
to exceptions/interrupts:
    ```
    acquire()
    try {
      ...
    } finally {
      release()
    }
    ```
    
    ### Why are the changes needed?
    
    There currently exists a race condition in RocksDB where the completion 
listener may release a lock validly held by another thread, resulting in a 
thread-unsafe condition. We have seen this bug in production.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added unit tests in RocksDBSuite.scala; all other unit tests continue to pass as 
expected.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Github Co-pilot 1.5.25.10
    
    Closes #48697 from liviazhu-db/liviazhu-db/rocksdb-acquirelock-fix.
    
    Authored-by: Livia Zhu <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 .../sql/execution/streaming/state/RocksDB.scala    |  85 ++++++---
 .../execution/streaming/state/RocksDBSuite.scala   | 190 ++++++++++++++++++++-
 2 files changed, 249 insertions(+), 26 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index aeac5ea71a2e..544035e11785 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -743,17 +743,20 @@ class RocksDB(
    */
   def rollback(): Unit = {
     acquire(RollbackStore)
-    numKeysOnWritingVersion = numKeysOnLoadedVersion
-    loadedVersion = -1L
-    lastCommitBasedStateStoreCkptId = None
-    lastCommittedStateStoreCkptId = None
-    loadedStateStoreCkptId = None
-    sessionStateStoreCkptId = None
-    changelogWriter.foreach(_.abort())
-    // Make sure changelogWriter gets recreated next time.
-    changelogWriter = None
-    release(RollbackStore)
-    logInfo(log"Rolled back to ${MDC(LogKeys.VERSION_NUM, loadedVersion)}")
+    try {
+      numKeysOnWritingVersion = numKeysOnLoadedVersion
+      loadedVersion = -1L
+      lastCommitBasedStateStoreCkptId = None
+      lastCommittedStateStoreCkptId = None
+      loadedStateStoreCkptId = None
+      sessionStateStoreCkptId = None
+      changelogWriter.foreach(_.abort())
+      // Make sure changelogWriter gets recreated next time.
+      changelogWriter = None
+      logInfo(log"Rolled back to ${MDC(LogKeys.VERSION_NUM, loadedVersion)}")
+    } finally {
+      release(RollbackStore)
+    }
   }
 
   def doMaintenance(): Unit = {
@@ -787,9 +790,9 @@ class RocksDB(
 
   /** Release all resources */
   def close(): Unit = {
+    // Acquire DB instance lock and release at the end to allow for 
synchronized access
+    acquire(CloseStore)
     try {
-      // Acquire DB instance lock and release at the end to allow for 
synchronized access
-      acquire(CloseStore)
       closeDB()
 
       readOptions.close()
@@ -903,8 +906,8 @@ class RocksDB(
    */
   def metricsOpt: Option[RocksDBMetrics] = {
     var rocksDBMetricsOpt: Option[RocksDBMetrics] = None
+    acquire(ReportStoreMetrics)
     try {
-      acquire(ReportStoreMetrics)
       rocksDBMetricsOpt = recordedMetrics
     } catch {
       case ex: Exception =>
@@ -956,7 +959,7 @@ class RocksDB(
       acquiredThreadInfo = newAcquiredThreadInfo
       // Add a listener to always release the lock when the task (if active) 
completes
       Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit] {
-        _ => this.release(StoreTaskCompletionListener)
+        _ => this.release(StoreTaskCompletionListener, 
Some(newAcquiredThreadInfo))
       })
       logInfo(log"RocksDB instance was acquired by ${MDC(LogKeys.THREAD, 
acquiredThreadInfo)} " +
         log"for opType=${MDC(LogKeys.OP_TYPE, opType.toString)}")
@@ -965,16 +968,44 @@ class RocksDB(
 
   /**
    * Function to release RocksDB instance lock that allows for synchronized 
access to the state
-   * store instance
+   * store instance. Optionally provide a thread to check against, and release 
only if provided
+   * thread is the one that acquired the lock.
    *
    * @param opType - operation type releasing the lock
+   * @param releaseForThreadOpt - optional thread to check against acquired 
thread
    */
-  private def release(opType: RocksDBOpType): Unit = acquireLock.synchronized {
+  private def release(
+      opType: RocksDBOpType,
+      releaseForThreadOpt: Option[AcquiredThreadInfo] = None): Unit = 
acquireLock.synchronized {
     if (acquiredThreadInfo != null) {
-      logInfo(log"RocksDB instance was released by ${MDC(LogKeys.THREAD,
-        acquiredThreadInfo)} " + log"for opType=${MDC(LogKeys.OP_TYPE, 
opType.toString)}")
-      acquiredThreadInfo = null
-      acquireLock.notifyAll()
+      val release = releaseForThreadOpt match {
+        case Some(releaseForThread) if releaseForThread.threadRef.get.isEmpty 
=>
+          logInfo(log"Thread reference is empty when attempting to release 
for" +
+            log" opType=${MDC(LogKeys.OP_TYPE, opType.toString)}, ignoring 
release." +
+            log" Lock is held by ${MDC(LogKeys.THREAD, acquiredThreadInfo)}")
+          false
+        // NOTE: we compare the entire acquiredThreadInfo object to ensure 
that we are
+        // releasing not only for the right thread but the right task as well. 
This is
+        // inconsistent with the logic for acquire which uses only the thread 
ID, consider
+        // updating this in future.
+        case Some(releaseForThread) if acquiredThreadInfo != releaseForThread 
=>
+          logInfo(log"Thread info for release" +
+            log" ${MDC(LogKeys.THREAD, releaseForThreadOpt.get)}" +
+            log" does not match the acquired thread when attempting to" +
+            log" release for opType=${MDC(LogKeys.OP_TYPE, opType.toString)}, 
ignoring release." +
+            log" Lock is held by ${MDC(LogKeys.THREAD, acquiredThreadInfo)}")
+          false
+        case _ => true
+      }
+
+      if (release) {
+        logInfo(log"RocksDB instance was released by " +
+          log"${MDC(LogKeys.THREAD, AcquiredThreadInfo())}. " +
+          log"acquiredThreadInfo: ${MDC(LogKeys.THREAD, acquiredThreadInfo)} " 
+
+          log"for opType=${MDC(LogKeys.OP_TYPE, opType.toString)}")
+        acquiredThreadInfo = null
+        acquireLock.notifyAll()
+      }
     }
   }
 
@@ -1001,6 +1032,11 @@ class RocksDB(
     }
   }
 
+  private[state] def getAcquiredThreadInfo(): Option[AcquiredThreadInfo] =
+      acquireLock.synchronized {
+    Option(acquiredThreadInfo).map(_.copy())
+  }
+
   /** Create a native RocksDB logger that forwards native logs to log4j with 
correct log levels. */
   private def createLogger(): Logger = {
     val dbLogger = new Logger(rocksDbOptions.infoLogLevel()) {
@@ -1456,10 +1492,9 @@ object RocksDBNativeHistogram {
   }
 }
 
-case class AcquiredThreadInfo() {
-  val threadRef: WeakReference[Thread] = new 
WeakReference[Thread](Thread.currentThread())
-  val tc: TaskContext = TaskContext.get()
-
+case class AcquiredThreadInfo(
+    threadRef: WeakReference[Thread] = new 
WeakReference[Thread](Thread.currentThread()),
+    tc: TaskContext = TaskContext.get()) {
   override def toString(): String = {
     val taskStr = if (tc != null) {
       val taskDetails =
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
index 8fde216c1441..3455e4f8387c 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
@@ -34,7 +34,7 @@ import org.rocksdb.CompressionType
 import org.scalactic.source.Position
 import org.scalatest.Tag
 
-import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.{SparkConf, SparkException, TaskContext}
 import org.apache.spark.sql.catalyst.util.quietly
 import org.apache.spark.sql.execution.streaming.{CreateAtomicTestManager, 
FileSystemBasedCheckpointFileManager}
 import 
org.apache.spark.sql.execution.streaming.CheckpointFileManager.{CancellableFSDataOutputStream,
 RenameBasedFSDataOutputStream}
@@ -2240,6 +2240,194 @@ class RocksDBSuite extends 
AlsoTestWithChangelogCheckpointingEnabled with Shared
     }
   }
 
+  test("Rocks DB task completion listener does not double unlock 
acquireThread") {
+    // This test verifies that a thread that locks then unlocks the db and then
+    // fires a completion listener (Thread 1) does not unlock the lock validly
+    // acquired by another thread (Thread 2).
+    //
+    // Timeline of this test (* means thread is active):
+    // STATE | MAIN             | THREAD 1         | THREAD 2         |
+    // ------| ---------------- | ---------------- | ---------------- |
+    // 0.    | wait for s3      | *load, commit    | wait for s1      |
+    //       |                  | *signal s1       |                  |
+    // ------| ---------------- | ---------------- | ---------------- |
+    // 1.    |                  | wait for s2      | *load, signal s2 |
+    // ------| ---------------- | ---------------- | ---------------- |
+    // 2.    |                  | *task complete   | wait for s4      |
+    //       |                  | *signal s3, END  |                  |
+    // ------| ---------------- | ---------------- | ---------------- |
+    // 3.    | *verify locked   |                  |                  |
+    //       | *signal s4       |                  |                  |
+    // ------| ---------------- | ---------------- | ---------------- |
+    // 4.    | wait for s5      |                  | *commit          |
+    //       |                  |                  | *signal s5, END  |
+    // ------| ---------------- | ---------------- | ---------------- |
+    // 5.    | *close db, END   |                  |                  |
+    //
+    // NOTE: state 4 and 5 are only for cleanup
+
+    // Create a custom ExecutionContext with 3 threads
+    implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(
+      ThreadUtils.newDaemonFixedThreadPool(3, "pool-thread-executor"))
+    val stateLock = new Object()
+    var state = 0
+
+    withTempDir { dir =>
+      val remoteDir = dir.getCanonicalPath
+      val db = new RocksDB(
+        remoteDir,
+        conf = dbConf,
+        localRootDir = Utils.createTempDir(),
+        hadoopConf = new Configuration(),
+        loggingId = s"[Thread-${Thread.currentThread.getId}]",
+        useColumnFamilies = false
+      )
+      try {
+        Future { // THREAD 1
+          // Set thread 1's task context so that it is not a clone
+          // of the main thread's taskContext, which will end if the
+          // task is marked as complete
+          val taskContext = TaskContext.empty()
+          TaskContext.setTaskContext(taskContext)
+
+          stateLock.synchronized {
+            // -------------------- STATE 0 --------------------
+            // Simulate a task that loads and commits, db should be unlocked 
after
+            db.load(0)
+            db.put("a", "1")
+            db.commit()
+            // Signal that we have entered state 1
+            state = 1
+            stateLock.notifyAll()
+
+            // -------------------- STATE 2 --------------------
+            // Wait until we have entered state 2 (thread 2 has loaded db and 
acquired lock)
+            while (state != 2) {
+              stateLock.wait()
+            }
+
+            // thread 1's task context is marked as complete and signal
+            // that we have entered state 3
+            // At this point, thread 2 should still hold the DB lock.
+            taskContext.markTaskCompleted(None)
+            state = 3
+            stateLock.notifyAll()
+          }
+        }
+
+        Future { // THREAD 2
+          // Set thread 2's task context so that it is not a clone of thread 
1's
+          // so it won't be marked as complete
+          val taskContext = TaskContext.empty()
+          TaskContext.setTaskContext(taskContext)
+
+          stateLock.synchronized {
+            // -------------------- STATE 1 --------------------
+            // Wait until we have entered state 1 (thread 1 finished loading 
and committing)
+            while (state != 1) {
+              stateLock.wait()
+            }
+
+            // Load the db and signal that we have entered state 2
+            db.load(1)
+            assertAcquiredThreadIsCurrentThread(db)
+            state = 2
+            stateLock.notifyAll()
+
+            // -------------------- STATE 4 --------------------
+            // Wait until we have entered state 4 (thread 1 completed and
+            // main thread confirmed that lock is held)
+            while (state != 4) {
+              stateLock.wait()
+            }
+
+            // Ensure we still have the lock
+            assertAcquiredThreadIsCurrentThread(db)
+
+            // commit and signal that we have entered state 5
+            db.commit()
+            state = 5
+            stateLock.notifyAll()
+          }
+        }
+
+        // MAIN THREAD
+        stateLock.synchronized {
+          // -------------------- STATE 3 --------------------
+          // Wait until we have entered state 3 (thread 1 is complete)
+          while (state != 3) {
+            stateLock.wait()
+          }
+
+          // Verify that the lock is being held
+          val threadInfo = db.getAcquiredThreadInfo()
+          assert(threadInfo.nonEmpty, s"acquiredThreadInfo was None when it 
should be Some")
+
+          // Signal that we have entered state 4 (thread 2 can now release 
lock)
+          state = 4
+          stateLock.notifyAll()
+
+          // -------------------- STATE 5 --------------------
+          // Wait until we have entered state 5 (thread 2 has released lock)
+          // so that we can clean up
+          while (state != 5) {
+            stateLock.wait()
+          }
+        }
+      } finally {
+        db.close()
+      }
+    }
+  }
+
+  test("RocksDB task completion listener correctly releases for failed task") {
+    // This test verifies that a thread that locks the DB and then fails
+    // can rely on the completion listener to release the lock.
+
+    // Create a custom ExecutionContext with 1 thread
+    implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(
+      ThreadUtils.newDaemonSingleThreadExecutor("single-thread-executor"))
+    val timeout = 5.seconds
+
+    withTempDir { dir =>
+      val remoteDir = dir.getCanonicalPath
+      withDB(remoteDir) { db =>
+        // Release the lock acquired by withDB
+        db.commit()
+
+        // New task that will load and then complete with failure
+        val fut = Future {
+          val taskContext = TaskContext.empty()
+          TaskContext.setTaskContext(taskContext)
+
+          db.load(0)
+          assertAcquiredThreadIsCurrentThread(db)
+
+          // Task completion listener should unlock
+          taskContext.markTaskCompleted(
+            Some(new SparkException("Task failure injection")))
+        }
+
+        ThreadUtils.awaitResult(fut, timeout)
+
+        // Assert that db is not locked
+        val threadInfo = db.getAcquiredThreadInfo()
+        assert(threadInfo.isEmpty, s"acquiredThreadInfo should be None but was 
$threadInfo")
+      }
+    }
+  }
+
+  private def assertAcquiredThreadIsCurrentThread(db: RocksDB): Unit = {
+    val threadInfo = db.getAcquiredThreadInfo()
+    assert(threadInfo != None,
+      "acquired thread info should not be null after load")
+    val threadId = threadInfo.get.threadRef.get.get.getId
+    assert(
+      threadId == Thread.currentThread().getId,
+      s"acquired thread should be curent thread 
${Thread.currentThread().getId} " +
+        s"after load but was $threadId")
+  }
+
   private def dbConf = RocksDBConf(StateStoreConf(SQLConf.get.clone()))
 
   def withDB[T](


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to