This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 934134e99aed [SPARK-50163][SS] Fix the RocksDB extra acquireLock
release due to the completion listener
934134e99aed is described below
commit 934134e99aeda36f7795c46e73ab6a017d3113ad
Author: Livia Zhu <[email protected]>
AuthorDate: Sun Nov 3 07:40:03 2024 +0900
[SPARK-50163][SS] Fix the RocksDB extra acquireLock release due to the
completion listener
### What changes were proposed in this pull request?
Adds a new `releaseForThread(opType, threadRef)` method that only releases
the RocksDB acquireLock if the provided threadRef is the same as the threadRef
in `acquiredThreadInfo`. Also, modify the task completion listener created in
`acquire` to use the new method so that it doesn't accidentally release the
lock if it's owned by another thread.
Also some small changes to usages of acquire() and release() in `rollback`
and `close` to follow this pattern to avoid extraneous or missed releases due
to exceptions/interrupts:
```
acquire()
try {
...
} finally {
release()
}
```
### Why are the changes needed?
There currently exists a race condition in RocksDB where the completion
listener may release a lock validly held by another thread, resulting in a
thread-unsafe condition. We have seen this bug in production.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added unit tests in RocksDBSuite.scala, all other unit tests continue to pass
as expected.
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: GitHub Copilot 1.5.25.10
Closes #48697 from liviazhu-db/liviazhu-db/rocksdb-acquirelock-fix.
Authored-by: Livia Zhu <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
.../sql/execution/streaming/state/RocksDB.scala | 85 ++++++---
.../execution/streaming/state/RocksDBSuite.scala | 190 ++++++++++++++++++++-
2 files changed, 249 insertions(+), 26 deletions(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index aeac5ea71a2e..544035e11785 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -743,17 +743,20 @@ class RocksDB(
*/
def rollback(): Unit = {
acquire(RollbackStore)
- numKeysOnWritingVersion = numKeysOnLoadedVersion
- loadedVersion = -1L
- lastCommitBasedStateStoreCkptId = None
- lastCommittedStateStoreCkptId = None
- loadedStateStoreCkptId = None
- sessionStateStoreCkptId = None
- changelogWriter.foreach(_.abort())
- // Make sure changelogWriter gets recreated next time.
- changelogWriter = None
- release(RollbackStore)
- logInfo(log"Rolled back to ${MDC(LogKeys.VERSION_NUM, loadedVersion)}")
+ try {
+ numKeysOnWritingVersion = numKeysOnLoadedVersion
+ loadedVersion = -1L
+ lastCommitBasedStateStoreCkptId = None
+ lastCommittedStateStoreCkptId = None
+ loadedStateStoreCkptId = None
+ sessionStateStoreCkptId = None
+ changelogWriter.foreach(_.abort())
+ // Make sure changelogWriter gets recreated next time.
+ changelogWriter = None
+ logInfo(log"Rolled back to ${MDC(LogKeys.VERSION_NUM, loadedVersion)}")
+ } finally {
+ release(RollbackStore)
+ }
}
def doMaintenance(): Unit = {
@@ -787,9 +790,9 @@ class RocksDB(
/** Release all resources */
def close(): Unit = {
+ // Acquire DB instance lock and release at the end to allow for
synchronized access
+ acquire(CloseStore)
try {
- // Acquire DB instance lock and release at the end to allow for
synchronized access
- acquire(CloseStore)
closeDB()
readOptions.close()
@@ -903,8 +906,8 @@ class RocksDB(
*/
def metricsOpt: Option[RocksDBMetrics] = {
var rocksDBMetricsOpt: Option[RocksDBMetrics] = None
+ acquire(ReportStoreMetrics)
try {
- acquire(ReportStoreMetrics)
rocksDBMetricsOpt = recordedMetrics
} catch {
case ex: Exception =>
@@ -956,7 +959,7 @@ class RocksDB(
acquiredThreadInfo = newAcquiredThreadInfo
// Add a listener to always release the lock when the task (if active)
completes
Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit] {
- _ => this.release(StoreTaskCompletionListener)
+ _ => this.release(StoreTaskCompletionListener,
Some(newAcquiredThreadInfo))
})
logInfo(log"RocksDB instance was acquired by ${MDC(LogKeys.THREAD,
acquiredThreadInfo)} " +
log"for opType=${MDC(LogKeys.OP_TYPE, opType.toString)}")
@@ -965,16 +968,44 @@ class RocksDB(
/**
* Function to release RocksDB instance lock that allows for synchronized
access to the state
- * store instance
+ * store instance. Optionally provide a thread to check against, and release
only if provided
+ * thread is the one that acquired the lock.
*
* @param opType - operation type releasing the lock
+ * @param releaseForThreadOpt - optional thread to check against acquired
thread
*/
- private def release(opType: RocksDBOpType): Unit = acquireLock.synchronized {
+ private def release(
+ opType: RocksDBOpType,
+ releaseForThreadOpt: Option[AcquiredThreadInfo] = None): Unit =
acquireLock.synchronized {
if (acquiredThreadInfo != null) {
- logInfo(log"RocksDB instance was released by ${MDC(LogKeys.THREAD,
- acquiredThreadInfo)} " + log"for opType=${MDC(LogKeys.OP_TYPE,
opType.toString)}")
- acquiredThreadInfo = null
- acquireLock.notifyAll()
+ val release = releaseForThreadOpt match {
+ case Some(releaseForThread) if releaseForThread.threadRef.get.isEmpty
=>
+ logInfo(log"Thread reference is empty when attempting to release
for" +
+ log" opType=${MDC(LogKeys.OP_TYPE, opType.toString)}, ignoring
release." +
+ log" Lock is held by ${MDC(LogKeys.THREAD, acquiredThreadInfo)}")
+ false
+ // NOTE: we compare the entire acquiredThreadInfo object to ensure
that we are
+ // releasing not only for the right thread but the right task as well.
This is
+ // inconsistent with the logic for acquire which uses only the thread
ID, consider
+ // updating this in future.
+ case Some(releaseForThread) if acquiredThreadInfo != releaseForThread
=>
+ logInfo(log"Thread info for release" +
+ log" ${MDC(LogKeys.THREAD, releaseForThreadOpt.get)}" +
+ log" does not match the acquired thread when attempting to" +
+ log" release for opType=${MDC(LogKeys.OP_TYPE, opType.toString)},
ignoring release." +
+ log" Lock is held by ${MDC(LogKeys.THREAD, acquiredThreadInfo)}")
+ false
+ case _ => true
+ }
+
+ if (release) {
+ logInfo(log"RocksDB instance was released by " +
+ log"${MDC(LogKeys.THREAD, AcquiredThreadInfo())}. " +
+ log"acquiredThreadInfo: ${MDC(LogKeys.THREAD, acquiredThreadInfo)} "
+
+ log"for opType=${MDC(LogKeys.OP_TYPE, opType.toString)}")
+ acquiredThreadInfo = null
+ acquireLock.notifyAll()
+ }
}
}
@@ -1001,6 +1032,11 @@ class RocksDB(
}
}
+ private[state] def getAcquiredThreadInfo(): Option[AcquiredThreadInfo] =
+ acquireLock.synchronized {
+ Option(acquiredThreadInfo).map(_.copy())
+ }
+
/** Create a native RocksDB logger that forwards native logs to log4j with
correct log levels. */
private def createLogger(): Logger = {
val dbLogger = new Logger(rocksDbOptions.infoLogLevel()) {
@@ -1456,10 +1492,9 @@ object RocksDBNativeHistogram {
}
}
-case class AcquiredThreadInfo() {
- val threadRef: WeakReference[Thread] = new
WeakReference[Thread](Thread.currentThread())
- val tc: TaskContext = TaskContext.get()
-
+case class AcquiredThreadInfo(
+ threadRef: WeakReference[Thread] = new
WeakReference[Thread](Thread.currentThread()),
+ tc: TaskContext = TaskContext.get()) {
override def toString(): String = {
val taskStr = if (tc != null) {
val taskDetails =
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
index 8fde216c1441..3455e4f8387c 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
@@ -34,7 +34,7 @@ import org.rocksdb.CompressionType
import org.scalactic.source.Position
import org.scalatest.Tag
-import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.{SparkConf, SparkException, TaskContext}
import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.sql.execution.streaming.{CreateAtomicTestManager,
FileSystemBasedCheckpointFileManager}
import
org.apache.spark.sql.execution.streaming.CheckpointFileManager.{CancellableFSDataOutputStream,
RenameBasedFSDataOutputStream}
@@ -2240,6 +2240,194 @@ class RocksDBSuite extends
AlsoTestWithChangelogCheckpointingEnabled with Shared
}
}
+ test("Rocks DB task completion listener does not double unlock
acquireThread") {
+ // This test verifies that a thread that locks then unlocks the db and then
+ // fires a completion listener (Thread 1) does not unlock the lock validly
+ // acquired by another thread (Thread 2).
+ //
+ // Timeline of this test (* means thread is active):
+ // STATE | MAIN | THREAD 1 | THREAD 2 |
+ // ------| ---------------- | ---------------- | ---------------- |
+ // 0. | wait for s3 | *load, commit | wait for s1 |
+ // | | *signal s1 | |
+ // ------| ---------------- | ---------------- | ---------------- |
+ // 1. | | wait for s2 | *load, signal s2 |
+ // ------| ---------------- | ---------------- | ---------------- |
+ // 2. | | *task complete | wait for s4 |
+ // | | *signal s3, END | |
+ // ------| ---------------- | ---------------- | ---------------- |
+ // 3. | *verify locked | | |
+ // | *signal s4 | | |
+ // ------| ---------------- | ---------------- | ---------------- |
+ // 4. | wait for s5 | | *commit |
+ // | | | *signal s5, END |
+ // ------| ---------------- | ---------------- | ---------------- |
+ // 5. | *close db, END | | |
+ //
+ // NOTE: state 4 and 5 are only for cleanup
+
+ // Create a custom ExecutionContext with 3 threads
+ implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(
+ ThreadUtils.newDaemonFixedThreadPool(3, "pool-thread-executor"))
+ val stateLock = new Object()
+ var state = 0
+
+ withTempDir { dir =>
+ val remoteDir = dir.getCanonicalPath
+ val db = new RocksDB(
+ remoteDir,
+ conf = dbConf,
+ localRootDir = Utils.createTempDir(),
+ hadoopConf = new Configuration(),
+ loggingId = s"[Thread-${Thread.currentThread.getId}]",
+ useColumnFamilies = false
+ )
+ try {
+ Future { // THREAD 1
+ // Set thread 1's task context so that it is not a clone
+ // of the main thread's taskContext, which will end if the
+ // task is marked as complete
+ val taskContext = TaskContext.empty()
+ TaskContext.setTaskContext(taskContext)
+
+ stateLock.synchronized {
+ // -------------------- STATE 0 --------------------
+ // Simulate a task that loads and commits, db should be unlocked
after
+ db.load(0)
+ db.put("a", "1")
+ db.commit()
+ // Signal that we have entered state 1
+ state = 1
+ stateLock.notifyAll()
+
+ // -------------------- STATE 2 --------------------
+ // Wait until we have entered state 2 (thread 2 has loaded db and
acquired lock)
+ while (state != 2) {
+ stateLock.wait()
+ }
+
+ // thread 1's task context is marked as complete and signal
+ // that we have entered state 3
+ // At this point, thread 2 should still hold the DB lock.
+ taskContext.markTaskCompleted(None)
+ state = 3
+ stateLock.notifyAll()
+ }
+ }
+
+ Future { // THREAD 2
+ // Set thread 2's task context so that it is not a clone of thread
1's
+ // so it won't be marked as complete
+ val taskContext = TaskContext.empty()
+ TaskContext.setTaskContext(taskContext)
+
+ stateLock.synchronized {
+ // -------------------- STATE 1 --------------------
+ // Wait until we have entered state 1 (thread 1 finished loading
and committing)
+ while (state != 1) {
+ stateLock.wait()
+ }
+
+ // Load the db and signal that we have entered state 2
+ db.load(1)
+ assertAcquiredThreadIsCurrentThread(db)
+ state = 2
+ stateLock.notifyAll()
+
+ // -------------------- STATE 4 --------------------
+ // Wait until we have entered state 4 (thread 1 completed and
+ // main thread confirmed that lock is held)
+ while (state != 4) {
+ stateLock.wait()
+ }
+
+ // Ensure we still have the lock
+ assertAcquiredThreadIsCurrentThread(db)
+
+ // commit and signal that we have entered state 5
+ db.commit()
+ state = 5
+ stateLock.notifyAll()
+ }
+ }
+
+ // MAIN THREAD
+ stateLock.synchronized {
+ // -------------------- STATE 3 --------------------
+ // Wait until we have entered state 3 (thread 1 is complete)
+ while (state != 3) {
+ stateLock.wait()
+ }
+
+ // Verify that the lock is being held
+ val threadInfo = db.getAcquiredThreadInfo()
+ assert(threadInfo.nonEmpty, s"acquiredThreadInfo was None when it
should be Some")
+
+ // Signal that we have entered state 4 (thread 2 can now release
lock)
+ state = 4
+ stateLock.notifyAll()
+
+ // -------------------- STATE 5 --------------------
+ // Wait until we have entered state 5 (thread 2 has released lock)
+ // so that we can clean up
+ while (state != 5) {
+ stateLock.wait()
+ }
+ }
+ } finally {
+ db.close()
+ }
+ }
+ }
+
+ test("RocksDB task completion listener correctly releases for failed task") {
+ // This test verifies that a thread that locks the DB and then fails
+ // can rely on the completion listener to release the lock.
+
+ // Create a custom ExecutionContext with 1 thread
+ implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(
+ ThreadUtils.newDaemonSingleThreadExecutor("single-thread-executor"))
+ val timeout = 5.seconds
+
+ withTempDir { dir =>
+ val remoteDir = dir.getCanonicalPath
+ withDB(remoteDir) { db =>
+ // Release the lock acquired by withDB
+ db.commit()
+
+ // New task that will load and then complete with failure
+ val fut = Future {
+ val taskContext = TaskContext.empty()
+ TaskContext.setTaskContext(taskContext)
+
+ db.load(0)
+ assertAcquiredThreadIsCurrentThread(db)
+
+ // Task completion listener should unlock
+ taskContext.markTaskCompleted(
+ Some(new SparkException("Task failure injection")))
+ }
+
+ ThreadUtils.awaitResult(fut, timeout)
+
+ // Assert that db is not locked
+ val threadInfo = db.getAcquiredThreadInfo()
+ assert(threadInfo.isEmpty, s"acquiredThreadInfo should be None but was
$threadInfo")
+ }
+ }
+ }
+
+ private def assertAcquiredThreadIsCurrentThread(db: RocksDB): Unit = {
+ val threadInfo = db.getAcquiredThreadInfo()
+ assert(threadInfo != None,
+ "acquired thread info should not be null after load")
+ val threadId = threadInfo.get.threadRef.get.get.getId
+ assert(
+ threadId == Thread.currentThread().getId,
+ s"acquired thread should be curent thread
${Thread.currentThread().getId} " +
+ s"after load but was $threadId")
+ }
+
private def dbConf = RocksDBConf(StateStoreConf(SQLConf.get.clone()))
def withDB[T](
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]