hvanhovell commented on a change in pull request #34632:
URL: https://github.com/apache/spark/pull/34632#discussion_r752144293
##########
File path: core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
##########
@@ -341,106 +415,103 @@ private[storage] class BlockInfoManager extends Logging
{
*
* @return the ids of blocks whose pins were released
*/
- def releaseAllLocksForTask(taskAttemptId: TaskAttemptId): Seq[BlockId] =
synchronized {
+ def releaseAllLocksForTask(taskAttemptId: TaskAttemptId): Seq[BlockId] = {
val blocksWithReleasedLocks = mutable.ArrayBuffer[BlockId]()
- val readLocks =
readLocksByTask.remove(taskAttemptId).getOrElse(ImmutableMultiset.of[BlockId]())
- val writeLocks =
writeLocksByTask.remove(taskAttemptId).getOrElse(Seq.empty)
-
- for (blockId <- writeLocks) {
- infos.get(blockId).foreach { info =>
+ val writeLocks =
Option(writeLocksByTask.remove(taskAttemptId)).getOrElse(Collections.emptySet)
+ writeLocks.forEach { blockId =>
+ blockInfo(blockId) { (info, condition) =>
assert(info.writerTask == taskAttemptId)
info.writerTask = BlockInfo.NO_WRITER
+ condition.signalAll()
}
blocksWithReleasedLocks += blockId
}
- readLocks.entrySet().iterator().asScala.foreach { entry =>
+ val readLocks = Option(readLocksByTask.remove(taskAttemptId))
+ .getOrElse(ImmutableMultiset.of[BlockId])
+ readLocks.entrySet().forEach { entry =>
val blockId = entry.getElement
val lockCount = entry.getCount
blocksWithReleasedLocks += blockId
- get(blockId).foreach { info =>
+ blockInfo(blockId) { (info, condition) =>
info.readerCount -= lockCount
assert(info.readerCount >= 0)
+ condition.signalAll()
}
}
- notifyAll()
-
blocksWithReleasedLocks.toSeq
}
/** Returns the number of locks held by the given task. Used only for
testing. */
private[storage] def getTaskLockCount(taskAttemptId: TaskAttemptId): Int = {
- readLocksByTask.get(taskAttemptId).map(_.size()).getOrElse(0) +
- writeLocksByTask.get(taskAttemptId).map(_.size).getOrElse(0)
+ Option(readLocksByTask.get(taskAttemptId)).map(_.size()).getOrElse(0) +
+ Option(writeLocksByTask.get(taskAttemptId)).map(_.size).getOrElse(0)
}
/**
* Returns the number of blocks tracked.
*/
- def size: Int = synchronized {
- infos.size
- }
+ def size: Int = blockInfoWrappers.size
/**
* Return the number of map entries in this pin counter's internal data
structures.
* This is used in unit tests in order to detect memory leaks.
*/
- private[storage] def getNumberOfMapEntries: Long = synchronized {
+ private[storage] def getNumberOfMapEntries: Long = {
size +
readLocksByTask.size +
- readLocksByTask.map(_._2.size()).sum +
+ readLocksByTask.asScala.map(_._2.size()).sum +
writeLocksByTask.size +
- writeLocksByTask.map(_._2.size).sum
+ writeLocksByTask.asScala.map(_._2.size).sum
}
/**
* Returns an iterator over a snapshot of all blocks' metadata. Note that
the individual entries
* in this iterator are mutable and thus may reflect blocks that are deleted
while the iterator
* is being traversed.
*/
- def entries: Iterator[(BlockId, BlockInfo)] = synchronized {
- infos.toArray.toIterator
+ def entries: Iterator[(BlockId, BlockInfo)] = {
+ blockInfoWrappers.entrySet().iterator().asScala.map(kv => kv.getKey ->
kv.getValue.info)
}
/**
* Removes the given block and releases the write lock on it.
*
* This can only be called while holding a write lock on the given block.
*/
- def removeBlock(blockId: BlockId): Unit = synchronized {
- logTrace(s"Task $currentTaskAttemptId trying to remove block $blockId")
- infos.get(blockId) match {
- case Some(blockInfo) =>
- if (blockInfo.writerTask != currentTaskAttemptId) {
- throw new IllegalStateException(
- s"Task $currentTaskAttemptId called remove() on block $blockId
without a write lock")
- } else {
- infos.remove(blockId)
- blockInfo.readerCount = 0
- blockInfo.writerTask = BlockInfo.NO_WRITER
- writeLocksByTask.removeBinding(currentTaskAttemptId, blockId)
- }
- case None =>
- throw new IllegalArgumentException(
- s"Task $currentTaskAttemptId called remove() on non-existent block
$blockId")
+ def removeBlock(blockId: BlockId): Unit = {
+ val taskAttemptId = currentTaskAttemptId
+ logTrace(s"Task $taskAttemptId trying to remove block $blockId")
+ blockInfo(blockId) { (info, condition) =>
+ if (info.writerTask != taskAttemptId) {
+ throw new IllegalStateException(
+ s"Task $taskAttemptId called remove() on block $blockId without a
write lock")
+ } else {
+ blockInfoWrappers.remove(blockId)
+ info.readerCount = 0
+ info.writerTask = BlockInfo.NO_WRITER
+ writeLocksByTask.get(taskAttemptId).remove(blockId)
+ }
+ condition.signalAll()
}
- notifyAll()
}
/**
* Delete all state. Called during shutdown.
*/
- def clear(): Unit = synchronized {
- infos.valuesIterator.foreach { blockInfo =>
- blockInfo.readerCount = 0
- blockInfo.writerTask = BlockInfo.NO_WRITER
+ def clear(): Unit = {
+ blockInfoWrappers.values().forEach { wrapper =>
+ wrapper.withLock { (info, condition) =>
+ info.readerCount = 0
+ info.writerTask = BlockInfo.NO_WRITER
+ condition.signalAll()
+ }
Review comment:
Sure, will update.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]