Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2021/11/18 04:30:35 UTC

[GitHub] [spark] mridulm commented on a change in pull request #34632: [SPARK-37356][CORE] Add fine grained locking to the BlockInfoManager

mridulm commented on a change in pull request #34632:
URL: https://github.com/apache/spark/pull/34632#discussion_r751886165



##########
File path: core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
##########
@@ -122,24 +140,30 @@ private[storage] class BlockInfoManager extends Logging {
    * set-if-not-exists operation ([[lockNewBlockForWriting()]]) and are removed
    * by [[removeBlock()]].
    */
-  @GuardedBy("this")
-  private[this] val infos = new mutable.HashMap[BlockId, BlockInfo]
+  private[this] val blockInfoWrappers = new ConcurrentHashMap[BlockId, BlockInfoWrapper]
+
+  /**
+   * Stripe used to control multi-threaded access to block information.
+   *
+   * We are using this instead of the synchronizing on the [[BlockInfo]] objects to avoid race
+   * conditions in the `lockNewBlockForWriting` method. When this method returns successfully is is

Review comment:
       nit: `is is` -> `it is`
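
For readers following the design, a minimal sketch of the wrapper-plus-stripe pattern this diff introduces (the stripe count, type annotations, and helper shape here are illustrative assumptions, not the PR's actual code):

```scala
import java.util.concurrent.locks.{Condition, Lock}
import com.google.common.util.concurrent.Striped

// Pairs a BlockInfo with the striped lock guarding it; waiters park on a
// Condition created from that lock rather than on a global monitor.
private[storage] class BlockInfoWrapper(val info: BlockInfo, lock: Lock) {
  private val condition: Condition = lock.newCondition()

  // Runs `f` while holding the stripe lock; `f` can await state changes on
  // the Condition or wake other waiters with signalAll().
  def withLock[T](f: (BlockInfo, Condition) => T): T = {
    lock.lock()
    try f(info, condition) finally lock.unlock()
  }
}

// A fixed pool of locks shared by all block ids: two ids can hash to the same
// stripe (safe, but a source of false contention). 1024 is an assumed size.
private val locks: Striped[Lock] = Striped.lock(1024)
```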

##########
File path: core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
##########
@@ -166,6 +189,48 @@ private[storage] class BlockInfoManager extends Logging {
     Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(BlockInfo.NON_TASK_WRITER)
   }
 
+  /**
+   * Helper for lock acquisition.
+   */
+  private def acquireLock(
+      blockId: BlockId,
+      blocking: Boolean)(
+      f: BlockInfo => Boolean): Option[BlockInfo] = {
+    var done = false
+    var result: Option[BlockInfo] = None
+    while (!done) {
+      val wrapper = blockInfoWrappers.get(blockId)
+      if (wrapper == null) {
+        done = true
+      } else {
+        wrapper.withLock { (info, condition) =>
+          if (f(info)) {
+            result = Some(info)
+            done = true
+          } else if (!blocking) {
+            done = true
+          } else {
+            condition.await()

Review comment:
       Note: There is a difference now. `synchronized` does not result in `InterruptedException`, while `await` can throw it.
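
To illustrate the flagged difference, a small hedged sketch (assumed names, not from the PR): a `Condition`-based wait loop is interruptible by default, and the uninterruptible variant has to be opted into explicitly:

```scala
import java.util.concurrent.locks.{Condition, Lock}

// Blocks until `pred` holds. condition.await() can throw InterruptedException;
// Spark interrupts task threads on kill, so callers now have to expect it.
def awaitUntil(lock: Lock, condition: Condition)(pred: () => Boolean): Unit = {
  lock.lock()
  try {
    while (!pred()) {
      condition.await() // interruptible wait
      // condition.awaitUninterruptibly() would keep waiting through an
      // interrupt and only re-assert the interrupt status on return.
    }
  } finally {
    lock.unlock()
  }
}
```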

##########
File path: core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
##########
@@ -319,18 +376,35 @@ private[storage] class BlockInfoManager extends Logging {
    */
   def lockNewBlockForWriting(
       blockId: BlockId,
-      newBlockInfo: BlockInfo): Boolean = synchronized {
+      newBlockInfo: BlockInfo): Boolean = {
     logTrace(s"Task $currentTaskAttemptId trying to put $blockId")
-    lockForReading(blockId) match {
-      case Some(info) =>
-        // Block already exists. This could happen if another thread races with us to compute
-        // the same block. In this case, just keep the read lock and return.
-        false
-      case None =>
-        // Block does not yet exist or is removed, so we are free to acquire the write lock
-        infos(blockId) = newBlockInfo
-        lockForWriting(blockId)
-        true
+    // Get the lock that will be associated with the to-be written block and lock it for the entire
+    // duration of this operation. This way we prevent race conditions when two threads try to write
+    // the same block at the same time.
+    val lock = locks.get(blockId)
+    lock.lock()
+    try {
+      val wrapper = new BlockInfoWrapper(newBlockInfo, lock)
+      while (true) {
+        val previous = blockInfoWrappers.putIfAbsent(blockId, wrapper)
+        if (previous == null) {
+          // New block; lock it for writing.
+          val result = lockForWriting(blockId, blocking = false)
+          assert(result.isDefined)

Review comment:
       QQ: Can there be a race (causing an assertion failure) between `putIfAbsent` and `lockForWriting`?
   Do we want to return `true` if `result.isDefined`, and retry the loop otherwise?
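
A rough sketch of that suggestion (a hypothetical shape reusing the diff's names, not the PR's code): treat a failed non-blocking `lockForWriting` as a lost race and retry, returning `true` only once the write lock is actually held:

```scala
private def putAndLockForWriting(blockId: BlockId, wrapper: BlockInfoWrapper): Boolean = {
  while (true) {
    val previous = blockInfoWrappers.putIfAbsent(blockId, wrapper)
    if (previous == null) {
      // Our wrapper went in; take the write lock without blocking.
      if (lockForWriting(blockId, blocking = false).isDefined) {
        return true
      }
      // The block was concurrently removed between putIfAbsent and
      // lockForWriting: retry the insert instead of asserting.
    } else if (lockForReading(blockId).isDefined) {
      // Another thread created the block first: keep the read lock and
      // report failure, matching lockNewBlockForWriting's contract.
      return false
    }
    // The existing block vanished before we could read-lock it; retry.
  }
  false // unreachable, satisfies the result type
}
```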

##########
File path: core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
##########
@@ -341,106 +415,103 @@ private[storage] class BlockInfoManager extends Logging {
    *
    * @return the ids of blocks whose pins were released
    */
-  def releaseAllLocksForTask(taskAttemptId: TaskAttemptId): Seq[BlockId] = synchronized {
+  def releaseAllLocksForTask(taskAttemptId: TaskAttemptId): Seq[BlockId] = {
     val blocksWithReleasedLocks = mutable.ArrayBuffer[BlockId]()
 
-    val readLocks = readLocksByTask.remove(taskAttemptId).getOrElse(ImmutableMultiset.of[BlockId]())
-    val writeLocks = writeLocksByTask.remove(taskAttemptId).getOrElse(Seq.empty)
-
-    for (blockId <- writeLocks) {
-      infos.get(blockId).foreach { info =>
+    val writeLocks = Option(writeLocksByTask.remove(taskAttemptId)).getOrElse(Collections.emptySet)
+    writeLocks.forEach { blockId =>
+      blockInfo(blockId) { (info, condition) =>
         assert(info.writerTask == taskAttemptId)
         info.writerTask = BlockInfo.NO_WRITER
+        condition.signalAll()
       }
       blocksWithReleasedLocks += blockId
     }
 
-    readLocks.entrySet().iterator().asScala.foreach { entry =>
+    val readLocks = Option(readLocksByTask.remove(taskAttemptId))
+      .getOrElse(ImmutableMultiset.of[BlockId])
+    readLocks.entrySet().forEach { entry =>

Review comment:
       I am still checking, but do we want to do both `remove`s before iterating over them?
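
For comparison, the pre-PR code removed both registries up front before touching any block state; a hedged sketch of restoring that ordering with the new concurrent maps:

```scala
// Remove both task-lock registries first, so the rollback below works on a
// stable snapshot of everything the task held.
val writeLocks = Option(writeLocksByTask.remove(taskAttemptId))
  .getOrElse(Collections.emptySet[BlockId])
val readLocks = Option(readLocksByTask.remove(taskAttemptId))
  .getOrElse(ImmutableMultiset.of[BlockId])
// ...then iterate writeLocks and readLocks exactly as in the diff above.
```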

##########
File path: core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
##########
@@ -341,106 +415,103 @@ private[storage] class BlockInfoManager extends Logging {
    *
    * @return the ids of blocks whose pins were released
    */
-  def releaseAllLocksForTask(taskAttemptId: TaskAttemptId): Seq[BlockId] = synchronized {
+  def releaseAllLocksForTask(taskAttemptId: TaskAttemptId): Seq[BlockId] = {
     val blocksWithReleasedLocks = mutable.ArrayBuffer[BlockId]()
 
-    val readLocks = readLocksByTask.remove(taskAttemptId).getOrElse(ImmutableMultiset.of[BlockId]())
-    val writeLocks = writeLocksByTask.remove(taskAttemptId).getOrElse(Seq.empty)
-
-    for (blockId <- writeLocks) {
-      infos.get(blockId).foreach { info =>
+    val writeLocks = Option(writeLocksByTask.remove(taskAttemptId)).getOrElse(Collections.emptySet)
+    writeLocks.forEach { blockId =>
+      blockInfo(blockId) { (info, condition) =>
         assert(info.writerTask == taskAttemptId)
         info.writerTask = BlockInfo.NO_WRITER
+        condition.signalAll()
       }
       blocksWithReleasedLocks += blockId
     }
 
-    readLocks.entrySet().iterator().asScala.foreach { entry =>
+    val readLocks = Option(readLocksByTask.remove(taskAttemptId))
+      .getOrElse(ImmutableMultiset.of[BlockId])
+    readLocks.entrySet().forEach { entry =>
       val blockId = entry.getElement
       val lockCount = entry.getCount
       blocksWithReleasedLocks += blockId
-      get(blockId).foreach { info =>
+      blockInfo(blockId) { (info, condition) =>
         info.readerCount -= lockCount
         assert(info.readerCount >= 0)
+        condition.signalAll()
       }
     }
 
-    notifyAll()
-
     blocksWithReleasedLocks.toSeq
   }
 
   /** Returns the number of locks held by the given task.  Used only for testing. */
   private[storage] def getTaskLockCount(taskAttemptId: TaskAttemptId): Int = {
-    readLocksByTask.get(taskAttemptId).map(_.size()).getOrElse(0) +
-      writeLocksByTask.get(taskAttemptId).map(_.size).getOrElse(0)
+    Option(readLocksByTask.get(taskAttemptId)).map(_.size()).getOrElse(0) +
+      Option(writeLocksByTask.get(taskAttemptId)).map(_.size).getOrElse(0)
   }
 
   /**
    * Returns the number of blocks tracked.
    */
-  def size: Int = synchronized {
-    infos.size
-  }
+  def size: Int = blockInfoWrappers.size
 
   /**
    * Return the number of map entries in this pin counter's internal data structures.
    * This is used in unit tests in order to detect memory leaks.
    */
-  private[storage] def getNumberOfMapEntries: Long = synchronized {
+  private[storage] def getNumberOfMapEntries: Long = {
     size +
       readLocksByTask.size +
-      readLocksByTask.map(_._2.size()).sum +
+      readLocksByTask.asScala.map(_._2.size()).sum +
       writeLocksByTask.size +
-      writeLocksByTask.map(_._2.size).sum
+      writeLocksByTask.asScala.map(_._2.size).sum
   }
 
   /**
    * Returns an iterator over a snapshot of all blocks' metadata. Note that the individual entries
    * in this iterator are mutable and thus may reflect blocks that are deleted while the iterator
    * is being traversed.
    */
-  def entries: Iterator[(BlockId, BlockInfo)] = synchronized {
-    infos.toArray.toIterator
+  def entries: Iterator[(BlockId, BlockInfo)] = {
+    blockInfoWrappers.entrySet().iterator().asScala.map(kv => kv.getKey -> kv.getValue.info)
   }
 
   /**
    * Removes the given block and releases the write lock on it.
    *
    * This can only be called while holding a write lock on the given block.
    */
-  def removeBlock(blockId: BlockId): Unit = synchronized {
-    logTrace(s"Task $currentTaskAttemptId trying to remove block $blockId")
-    infos.get(blockId) match {
-      case Some(blockInfo) =>
-        if (blockInfo.writerTask != currentTaskAttemptId) {
-          throw new IllegalStateException(
-            s"Task $currentTaskAttemptId called remove() on block $blockId without a write lock")
-        } else {
-          infos.remove(blockId)
-          blockInfo.readerCount = 0
-          blockInfo.writerTask = BlockInfo.NO_WRITER
-          writeLocksByTask.removeBinding(currentTaskAttemptId, blockId)
-        }
-      case None =>
-        throw new IllegalArgumentException(
-          s"Task $currentTaskAttemptId called remove() on non-existent block $blockId")
+  def removeBlock(blockId: BlockId): Unit = {
+    val taskAttemptId = currentTaskAttemptId
+    logTrace(s"Task $taskAttemptId trying to remove block $blockId")
+    blockInfo(blockId) { (info, condition) =>
+      if (info.writerTask != taskAttemptId) {
+        throw new IllegalStateException(
+          s"Task $taskAttemptId called remove() on block $blockId without a write lock")
+      } else {
+        blockInfoWrappers.remove(blockId)
+        info.readerCount = 0
+        info.writerTask = BlockInfo.NO_WRITER
+        writeLocksByTask.get(taskAttemptId).remove(blockId)
+      }
+      condition.signalAll()
     }
-    notifyAll()
   }
 
   /**
    * Delete all state. Called during shutdown.
    */
-  def clear(): Unit = synchronized {
-    infos.valuesIterator.foreach { blockInfo =>
-      blockInfo.readerCount = 0
-      blockInfo.writerTask = BlockInfo.NO_WRITER
+  def clear(): Unit = {
+    blockInfoWrappers.values().forEach { wrapper =>
+      wrapper.withLock { (info, condition) =>
+        info.readerCount = 0
+        info.writerTask = BlockInfo.NO_WRITER
+        condition.signalAll()
+      }

Review comment:
       nit: Given we are enhancing this code path, do we want to do a `tryLock` for `clear` instead?
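
A sketch of what that could look like; `tryWithLock` is an assumed helper on the wrapper, not something the PR defines:

```scala
// Inside BlockInfoWrapper: a non-blocking variant of withLock that skips the
// callback when the stripe lock is contended, instead of parking the caller.
def tryWithLock(f: (BlockInfo, Condition) => Unit): Boolean = {
  if (!lock.tryLock()) return false
  try { f(info, condition); true } finally lock.unlock()
}

// clear() then becomes best-effort at shutdown: a wrapper whose lock is held
// is simply skipped, which is arguably acceptable since the maps are dropped
// immediately afterwards anyway.
def clear(): Unit = {
  blockInfoWrappers.values().forEach { wrapper =>
    wrapper.tryWithLock { (info, condition) =>
      info.readerCount = 0
      info.writerTask = BlockInfo.NO_WRITER
      condition.signalAll()
    }
  }
  blockInfoWrappers.clear()
  readLocksByTask.clear()
  writeLocksByTask.clear()
}
```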




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


