mridulm commented on a change in pull request #34632:
URL: https://github.com/apache/spark/pull/34632#discussion_r751886165



##########
File path: core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
##########
@@ -122,24 +140,30 @@ private[storage] class BlockInfoManager extends Logging {
    * set-if-not-exists operation ([[lockNewBlockForWriting()]]) and are removed
    * by [[removeBlock()]].
    */
-  @GuardedBy("this")
-  private[this] val infos = new mutable.HashMap[BlockId, BlockInfo]
+  private[this] val blockInfoWrappers = new ConcurrentHashMap[BlockId, BlockInfoWrapper]
+
+  /**
+   * Stripe used to control multi-threaded access to block information.
+   *
+   * We are using this instead of the synchronizing on the [[BlockInfo]] objects to avoid race
+   * conditions in the `lockNewBlockForWriting` method. When this method returns successfully is is

Review comment:
       nit: `is is` -> `it is`
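
   For readers following along: the hunk above replaces per-`BlockInfo` synchronization with a striped lock pool. A minimal self-contained sketch of that idea using Guava's `Striped` (the stripe count and helper names below are illustrative, not the PR's actual values):

   ```scala
   import java.util.concurrent.locks.Lock
   import com.google.common.util.concurrent.Striped

   object StripedLockSketch {
     // One fixed pool of locks shared across all keys: two threads touching
     // the same BlockId-like key always contend on the same Lock, without
     // paying for one lock object per map entry.
     private val locks: Striped[Lock] = Striped.lock(64) // stripe count is illustrative

     def withKeyLock[T](key: String)(body: => T): T = {
       val lock = locks.get(key)
       lock.lock()
       try body finally lock.unlock()
     }
   }
   ```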

##########
File path: core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
##########
@@ -166,6 +189,48 @@ private[storage] class BlockInfoManager extends Logging {
     Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(BlockInfo.NON_TASK_WRITER)
   }
 
+  /**
+   * Helper for lock acquisistion.
+   */
+  private def acquireLock(
+      blockId: BlockId,
+      blocking: Boolean)(
+      f: BlockInfo => Boolean): Option[BlockInfo] = {
+    var done = false
+    var result: Option[BlockInfo] = None
+    while(!done) {
+      val wrapper = blockInfoWrappers.get(blockId)
+      if (wrapper == null) {
+        done = true
+      } else {
+        wrapper.withLock { (info, condition) =>
+          if (f(info)) {
+            result = Some(info)
+            done = true
+          } else if (!blocking) {
+            done = true
+          } else {
+            condition.await()

Review comment:
       Note: There is a difference now - `synchronized` does not result in an `InterruptedException`, while `await` can throw the exception.
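
   To make the behavioral difference concrete, here is a small self-contained sketch (class name and timings are made up for illustration): a thread parked in `Condition.await()` is woken by `interrupt()` with an `InterruptedException`, while a thread blocked *entering* a `synchronized` block only has its interrupt flag set.

   ```scala
   import java.util.concurrent.locks.ReentrantLock

   object AwaitInterruptDemo {
     def main(args: Array[String]): Unit = {
       val lock = new ReentrantLock()
       val condition = lock.newCondition()

       // Thread parked in Condition.await(): interrupt() wakes it and the
       // call throws InterruptedException.
       val waiter = new Thread(() => {
         lock.lock()
         try condition.await()
         catch { case _: InterruptedException => println("await() threw InterruptedException") }
         finally lock.unlock()
       })
       waiter.start()
       Thread.sleep(200)
       waiter.interrupt()
       waiter.join()

       // Thread blocked entering a synchronized block: the interrupt only
       // sets the flag, no exception is thrown.
       val monitor = new Object
       val blocked = new Thread(() => monitor.synchronized {
         println(s"entered monitor, interrupted flag = ${Thread.interrupted()}")
       })
       monitor.synchronized {
         blocked.start()
         Thread.sleep(200)
         blocked.interrupt() // sets the flag; blocked keeps waiting for the monitor
         Thread.sleep(200)
       }
       blocked.join() // prints "entered monitor, interrupted flag = true"
     }
   }
   ```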

##########
File path: core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
##########
@@ -319,18 +376,35 @@ private[storage] class BlockInfoManager extends Logging {
    */
   def lockNewBlockForWriting(
       blockId: BlockId,
-      newBlockInfo: BlockInfo): Boolean = synchronized {
+      newBlockInfo: BlockInfo): Boolean = {
     logTrace(s"Task $currentTaskAttemptId trying to put $blockId")
-    lockForReading(blockId) match {
-      case Some(info) =>
-        // Block already exists. This could happen if another thread races with us to compute
-        // the same block. In this case, just keep the read lock and return.
-        false
-      case None =>
-        // Block does not yet exist or is removed, so we are free to acquire the write lock
-        infos(blockId) = newBlockInfo
-        lockForWriting(blockId)
-        true
+    // Get the lock that will be associated with the to-be written block and lock it for the entire
+    // duration of this operation. This way we prevent race conditions when two threads try to write
+    // the same block at the same time.
+    val lock = locks.get(blockId)
+    lock.lock()
+    try {
+      val wrapper = new BlockInfoWrapper(newBlockInfo, lock)
+      while (true) {
+        val previous = blockInfoWrappers.putIfAbsent(blockId, wrapper)
+        if (previous == null) {
+          // New block lock it for writing.
+          val result = lockForWriting(blockId, blocking = false)
+          assert(result.isDefined)

Review comment:
       QQ: Can there be a race (and hence an assertion failure) between `putIfAbsent` and `lockForWriting`?
   Do we want to return `true` if `result.isDefined`, and retry the loop otherwise?
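
   A rough sketch of the defensive shape this suggests, using stand-in types rather than the PR's actual classes: treat a failed non-blocking write-lock attempt as a lost race and go around the loop again instead of asserting.

   ```scala
   import java.util.concurrent.ConcurrentHashMap

   object PutIfAbsentRetrySketch {
     final class Info { var writeLocked = false } // stand-in for BlockInfo
     private val map = new ConcurrentHashMap[String, Info]()

     // Stand-in for lockForWriting(blocking = false): succeeds only if the
     // entry is still present and nobody else holds the write lock.
     private def tryWriteLock(id: String): Option[Info] =
       Option(map.get(id)).filter { info =>
         info.synchronized {
           if (info.writeLocked) false else { info.writeLocked = true; true }
         }
       }

     def lockNewBlock(id: String, fresh: Info): Boolean = {
       while (true) {
         if (map.putIfAbsent(id, fresh) == null) {
           // We inserted the entry; if a concurrent remove() raced us between
           // putIfAbsent and the lock attempt, retry instead of asserting.
           if (tryWriteLock(id).isDefined) return true
         } else {
           return false // entry exists; the real code takes a read lock here
         }
       }
       false // unreachable
     }
   }
   ```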

##########
File path: core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
##########
@@ -341,106 +415,103 @@ private[storage] class BlockInfoManager extends Logging {
    *
    * @return the ids of blocks whose pins were released
    */
-  def releaseAllLocksForTask(taskAttemptId: TaskAttemptId): Seq[BlockId] = synchronized {
+  def releaseAllLocksForTask(taskAttemptId: TaskAttemptId): Seq[BlockId] = {
     val blocksWithReleasedLocks = mutable.ArrayBuffer[BlockId]()
 
-    val readLocks = readLocksByTask.remove(taskAttemptId).getOrElse(ImmutableMultiset.of[BlockId]())
-    val writeLocks = writeLocksByTask.remove(taskAttemptId).getOrElse(Seq.empty)
-
-    for (blockId <- writeLocks) {
-      infos.get(blockId).foreach { info =>
+    val writeLocks = Option(writeLocksByTask.remove(taskAttemptId)).getOrElse(Collections.emptySet)
+    writeLocks.forEach { blockId =>
+      blockInfo(blockId) { (info, condition) =>
         assert(info.writerTask == taskAttemptId)
         info.writerTask = BlockInfo.NO_WRITER
+        condition.signalAll()
       }
       blocksWithReleasedLocks += blockId
     }
 
-    readLocks.entrySet().iterator().asScala.foreach { entry =>
+    val readLocks = Option(readLocksByTask.remove(taskAttemptId))
+      .getOrElse(ImmutableMultiset.of[BlockId])
+    readLocks.entrySet().forEach { entry =>

Review comment:
       I am still checking, but do we want to do both `remove`s before iterating over them?
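
   For illustration, the reordering being asked about would look roughly like this (stand-in maps, not the PR's fields): both per-task entries are drained up front, and only then is the per-block release work done.

   ```scala
   import java.util.{Collections, Set => JSet}
   import java.util.concurrent.ConcurrentHashMap
   import scala.collection.JavaConverters._

   object DrainFirstSketch {
     private val writeLocksByTask = new ConcurrentHashMap[Long, JSet[String]]()
     private val readLocksByTask = new ConcurrentHashMap[Long, JSet[String]]()

     def releaseAllLocksForTask(task: Long): Seq[String] = {
       // Remove both entries first so the task's lock bookkeeping disappears
       // before any per-block release work begins.
       val writes = Option(writeLocksByTask.remove(task)).getOrElse(Collections.emptySet[String]())
       val reads = Option(readLocksByTask.remove(task)).getOrElse(Collections.emptySet[String]())
       // ... per-block release logic would go here ...
       (writes.asScala ++ reads.asScala).toSeq
     }
   }
   ```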

##########
File path: core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
##########
@@ -341,106 +415,103 @@ private[storage] class BlockInfoManager extends Logging {
    *
    * @return the ids of blocks whose pins were released
    */
-  def releaseAllLocksForTask(taskAttemptId: TaskAttemptId): Seq[BlockId] = synchronized {
+  def releaseAllLocksForTask(taskAttemptId: TaskAttemptId): Seq[BlockId] = {
     val blocksWithReleasedLocks = mutable.ArrayBuffer[BlockId]()
 
-    val readLocks = readLocksByTask.remove(taskAttemptId).getOrElse(ImmutableMultiset.of[BlockId]())
-    val writeLocks = writeLocksByTask.remove(taskAttemptId).getOrElse(Seq.empty)
-
-    for (blockId <- writeLocks) {
-      infos.get(blockId).foreach { info =>
+    val writeLocks = Option(writeLocksByTask.remove(taskAttemptId)).getOrElse(Collections.emptySet)
+    writeLocks.forEach { blockId =>
+      blockInfo(blockId) { (info, condition) =>
         assert(info.writerTask == taskAttemptId)
         info.writerTask = BlockInfo.NO_WRITER
+        condition.signalAll()
       }
       blocksWithReleasedLocks += blockId
     }
 
-    readLocks.entrySet().iterator().asScala.foreach { entry =>
+    val readLocks = Option(readLocksByTask.remove(taskAttemptId))
+      .getOrElse(ImmutableMultiset.of[BlockId])
+    readLocks.entrySet().forEach { entry =>
       val blockId = entry.getElement
       val lockCount = entry.getCount
       blocksWithReleasedLocks += blockId
-      get(blockId).foreach { info =>
+      blockInfo(blockId) { (info, condition) =>
         info.readerCount -= lockCount
         assert(info.readerCount >= 0)
+        condition.signalAll()
       }
     }
 
-    notifyAll()
-
     blocksWithReleasedLocks.toSeq
   }
 
   /** Returns the number of locks held by the given task.  Used only for testing. */
   private[storage] def getTaskLockCount(taskAttemptId: TaskAttemptId): Int = {
-    readLocksByTask.get(taskAttemptId).map(_.size()).getOrElse(0) +
-      writeLocksByTask.get(taskAttemptId).map(_.size).getOrElse(0)
+    Option(readLocksByTask.get(taskAttemptId)).map(_.size()).getOrElse(0) +
+      Option(writeLocksByTask.get(taskAttemptId)).map(_.size).getOrElse(0)
   }
 
   /**
    * Returns the number of blocks tracked.
    */
-  def size: Int = synchronized {
-    infos.size
-  }
+  def size: Int = blockInfoWrappers.size
 
   /**
   * Return the number of map entries in this pin counter's internal data structures.
    * This is used in unit tests in order to detect memory leaks.
    */
-  private[storage] def getNumberOfMapEntries: Long = synchronized {
+  private[storage] def getNumberOfMapEntries: Long = {
     size +
       readLocksByTask.size +
-      readLocksByTask.map(_._2.size()).sum +
+      readLocksByTask.asScala.map(_._2.size()).sum +
       writeLocksByTask.size +
-      writeLocksByTask.map(_._2.size).sum
+      writeLocksByTask.asScala.map(_._2.size).sum
   }
 
   /**
   * Returns an iterator over a snapshot of all blocks' metadata. Note that the individual entries
   * in this iterator are mutable and thus may reflect blocks that are deleted while the iterator
    * is being traversed.
    */
-  def entries: Iterator[(BlockId, BlockInfo)] = synchronized {
-    infos.toArray.toIterator
+  def entries: Iterator[(BlockId, BlockInfo)] = {
+    blockInfoWrappers.entrySet().iterator().asScala.map(kv => kv.getKey -> kv.getValue.info)
   }
 
   /**
    * Removes the given block and releases the write lock on it.
    *
    * This can only be called while holding a write lock on the given block.
    */
-  def removeBlock(blockId: BlockId): Unit = synchronized {
-    logTrace(s"Task $currentTaskAttemptId trying to remove block $blockId")
-    infos.get(blockId) match {
-      case Some(blockInfo) =>
-        if (blockInfo.writerTask != currentTaskAttemptId) {
-          throw new IllegalStateException(
-            s"Task $currentTaskAttemptId called remove() on block $blockId 
without a write lock")
-        } else {
-          infos.remove(blockId)
-          blockInfo.readerCount = 0
-          blockInfo.writerTask = BlockInfo.NO_WRITER
-          writeLocksByTask.removeBinding(currentTaskAttemptId, blockId)
-        }
-      case None =>
-        throw new IllegalArgumentException(
-          s"Task $currentTaskAttemptId called remove() on non-existent block 
$blockId")
+  def removeBlock(blockId: BlockId): Unit = {
+    val taskAttemptId = currentTaskAttemptId
+    logTrace(s"Task $taskAttemptId trying to remove block $blockId")
+    blockInfo(blockId) { (info, condition) =>
+      if (info.writerTask != taskAttemptId) {
+        throw new IllegalStateException(
+          s"Task $taskAttemptId called remove() on block $blockId without a 
write lock")
+      } else {
+        blockInfoWrappers.remove(blockId)
+        info.readerCount = 0
+        info.writerTask = BlockInfo.NO_WRITER
+        writeLocksByTask.get(taskAttemptId).remove(blockId)
+      }
+      condition.signalAll()
     }
-    notifyAll()
   }
 
   /**
    * Delete all state. Called during shutdown.
    */
-  def clear(): Unit = synchronized {
-    infos.valuesIterator.foreach { blockInfo =>
-      blockInfo.readerCount = 0
-      blockInfo.writerTask = BlockInfo.NO_WRITER
+  def clear(): Unit = {
+    blockInfoWrappers.values().forEach { wrapper =>
+      wrapper.withLock { (info, condition) =>
+        info.readerCount = 0
+        info.writerTask = BlockInfo.NO_WRITER
+        condition.signalAll()
+      }

Review comment:
       nit: Given we are enhancing this codepath, do we want to use `tryLock` in `clear` instead?
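
   A minimal sketch of what a `tryLock`-based `clear` could look like, with made-up stand-in types: during shutdown, skip (rather than park on) any wrapper whose lock a stuck task still holds.

   ```scala
   import java.util.concurrent.ConcurrentHashMap
   import java.util.concurrent.locks.ReentrantLock

   object TryLockClearSketch {
     final class Wrapper {                 // stand-in for BlockInfoWrapper
       val lock = new ReentrantLock()
       val condition = lock.newCondition()
       var readerCount = 0
     }
     private val wrappers = new ConcurrentHashMap[String, Wrapper]()

     def clear(): Unit = {
       wrappers.values().forEach { w =>
         // tryLock instead of lock(): never block shutdown on a contended entry.
         if (w.lock.tryLock()) {
           try {
             w.readerCount = 0
             w.condition.signalAll()
           } finally w.lock.unlock()
         }
       }
       wrappers.clear()
     }
   }
   ```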



