You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by zs...@apache.org on 2019/09/03 21:09:32 UTC

[spark] branch master updated: [SPARK-3137][CORE] Replace the global TorrentBroadcast lock with fine grained KeyLock

This is an automated email from the ASF dual-hosted git repository.

zsxwing pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8980093  [SPARK-3137][CORE] Replace the global TorrentBroadcast lock with fine grained KeyLock
8980093 is described below

commit 89800931aa8b565335e45e1d26ff60402e46c534
Author: Shixiong Zhu <zs...@gmail.com>
AuthorDate: Tue Sep 3 14:09:07 2019 -0700

    [SPARK-3137][CORE] Replace the global TorrentBroadcast lock with fine grained KeyLock
    
    ### What changes were proposed in this pull request?
    
    This PR provides a new lock mechanism, `KeyLock`, that locks on a given key. It also uses this new lock in `TorrentBroadcast` so that tasks fetching different broadcast values no longer block each other.
    
    ### Why are the changes needed?
    
    `TorrentBroadcast.readObject` uses a global lock, so only one task at a time can fetch broadcast blocks. This is not optimal when multiple stages run concurrently, because each stage should be able to fetch its own broadcast blocks independently.
    
    ### Does this PR introduce any user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests, plus the new `KeyLockSuite`.
    
    Closes #25612 from zsxwing/SPARK-3137.
    
    Authored-by: Shixiong Zhu <zs...@gmail.com>
    Signed-off-by: Shixiong Zhu <zs...@gmail.com>
---
 .../apache/spark/broadcast/BroadcastManager.scala  |   9 +-
 .../apache/spark/broadcast/TorrentBroadcast.scala  |  20 ++--
 .../main/scala/org/apache/spark/util/KeyLock.scala |  69 ++++++++++++
 .../scala/org/apache/spark/util/KeyLockSuite.scala | 118 +++++++++++++++++++++
 4 files changed, 207 insertions(+), 9 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
index ed45043..9fa4745 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.broadcast
 
+import java.util.Collections
 import java.util.concurrent.atomic.AtomicLong
 
 import scala.reflect.ClassTag
@@ -55,9 +56,11 @@ private[spark] class BroadcastManager(
 
   private val nextBroadcastId = new AtomicLong(0)
 
-  private[broadcast] val cachedValues = {
-    new ReferenceMap(AbstractReferenceMap.HARD, AbstractReferenceMap.WEAK)
-  }
+  private[broadcast] val cachedValues =
+    Collections.synchronizedMap( // per-call thread safety; callers now lock per broadcast id only
+      new ReferenceMap(AbstractReferenceMap.HARD, AbstractReferenceMap.WEAK)
+        .asInstanceOf[java.util.Map[Any, Any]]
+    )
 
   def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
     val bid = nextBroadcastId.getAndIncrement()
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index f416be8..1379314 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -31,7 +31,7 @@ import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.storage._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{KeyLock, Utils}
 import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
 
 /**
@@ -167,7 +167,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
       bm.getLocalBytes(pieceId) match {
         case Some(block) =>
           blocks(pid) = block
-          releaseLock(pieceId)
+          releaseBlockManagerLock(pieceId)
         case None =>
           bm.getRemoteBytes(pieceId) match {
             case Some(b) =>
@@ -215,8 +215,10 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
   }
 
   private def readBroadcastBlock(): T = Utils.tryOrIOException {
-    val broadcastCache = SparkEnv.get.broadcastManager.cachedValues
-    broadcastCache.synchronized {
+    TorrentBroadcast.torrentBroadcastLock.withLock(broadcastId) {
+      // As we only lock based on `broadcastId`, whenever using `broadcastCache`, we should only
+      // touch `broadcastId`.
+      val broadcastCache = SparkEnv.get.broadcastManager.cachedValues
 
       Option(broadcastCache.get(broadcastId)).map(_.asInstanceOf[T]).getOrElse {
         setConf(SparkEnv.get.conf)
@@ -225,7 +227,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
           case Some(blockResult) =>
             if (blockResult.data.hasNext) {
               val x = blockResult.data.next().asInstanceOf[T]
-              releaseLock(broadcastId)
+              releaseBlockManagerLock(broadcastId)
 
               if (x != null) {
                 broadcastCache.put(broadcastId, x)
@@ -270,7 +272,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
    * If running in a task, register the given block's locks for release upon task completion.
    * Otherwise, if not running in a task then immediately release the lock.
    */
-  private def releaseLock(blockId: BlockId): Unit = {
+  private def releaseBlockManagerLock(blockId: BlockId): Unit = {
     val blockManager = SparkEnv.get.blockManager
     Option(TaskContext.get()) match {
       case Some(taskContext) =>
@@ -290,6 +292,12 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
 
 private object TorrentBroadcast extends Logging {
 
+  /**
+   * Lock keyed by [[BroadcastBlockId]]: one thread at a time may read or fetch a given
+   * [[TorrentBroadcast]] value, while different broadcasts can still be fetched concurrently.
+   */
+  private val torrentBroadcastLock = new KeyLock[BroadcastBlockId]
+
   def blockifyObject[T: ClassTag](
       obj: T,
       blockSize: Int,
diff --git a/core/src/main/scala/org/apache/spark/util/KeyLock.scala b/core/src/main/scala/org/apache/spark/util/KeyLock.scala
new file mode 100644
index 0000000..f9a96cd
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/KeyLock.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.concurrent.ConcurrentHashMap
+
+/**
+ * Mutual exclusion scoped to a key instead of a single global monitor. Calls whose keys are
+ * `equals` to one another are serialized, so only one of their `func`s runs at a time, while
+ * calls with unequal keys never wait on each other.
+ *
+ * @tparam K the key type. It must implement `equals` and `hashCode` consistently, because keys
+ *           are stored in an internal `ConcurrentHashMap`.
+ */
+private[spark] class KeyLock[K] {
+
+  // Maps each currently held key to the monitor object its waiters block on.
+  private val lockMap = new ConcurrentHashMap[K, AnyRef]()
+
+  /** Blocks until this caller is the one that installed a fresh monitor for `key`. */
+  @scala.annotation.tailrec
+  private def acquireLock(key: K): Unit = {
+    val holder = lockMap.putIfAbsent(key, new Object)
+    if (holder != null) {
+      // `key` is taken: wait until its monitor leaves the map, then try again.
+      holder.synchronized {
+        while (lockMap.get(key) eq holder) holder.wait()
+      }
+      acquireLock(key)
+    }
+  }
+
+  /** Drops the monitor for `key` from the map and wakes every thread waiting on it. */
+  private def releaseLock(key: K): Unit = {
+    val holder = lockMap.remove(key)
+    holder.synchronized { holder.notifyAll() }
+  }
+
+  /**
+   * Run `func` while holding the lock for `key` (keys compared with `equals`), returning its
+   * result. Not reentrant: nesting `withLock` with an equal key on the same thread deadlocks.
+   * Throws `NullPointerException` if `key` is null.
+   */
+  def withLock[T](key: K)(func: => T): T = {
+    if (key == null) throw new NullPointerException("key must not be null")
+    acquireLock(key)
+    try {
+      func
+    } finally {
+      // Release even when `func` throws so waiting callers are never stranded.
+      releaseLock(key)
+    }
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/KeyLockSuite.scala b/core/src/test/scala/org/apache/spark/util/KeyLockSuite.scala
new file mode 100644
index 0000000..2169a0e
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/KeyLockSuite.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit}
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.concurrent.duration._
+
+import org.scalatest.concurrent.{ThreadSignaler, TimeLimits}
+
+import org.apache.spark.SparkFunSuite
+
+class KeyLockSuite extends SparkFunSuite with TimeLimits {
+
+  // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x
+  private implicit val defaultSignaler: ThreadSignaler.type = ThreadSignaler
+
+  private val foreverMs = 60 * 1000L
+
+  test("The same key should wait when its lock is held") {
+    val keyLock = new KeyLock[Object]
+    val numThreads = 10
+    // Create different objects that are equal
+    val keys = List.fill(numThreads)(List(1))
+    require(keys.tail.forall(_ ne keys.head) && keys.tail.forall(_ == keys.head))
+
+    // A latch to make `withLock` be called almost at the same time
+    val latch = new CountDownLatch(1)
+    // Track how many threads get the lock at the same time
+    val numThreadsHoldingLock = new AtomicInteger(0)
+    // Track how many functions get called
+    val numFuncCalled = new AtomicInteger(0)
+    @volatile var e: Throwable = null // failure recorded by a worker thread, rethrown below
+    val threads = (0 until numThreads).map { i =>
+      new Thread() {
+        override def run(): Unit = try {
+          latch.await(foreverMs, TimeUnit.MILLISECONDS)
+          keyLock.withLock(keys(i)) {
+            var cur = numThreadsHoldingLock.get()
+            if (cur != 0) {
+              e = new AssertionError(s"numThreadsHoldingLock is not 0: $cur")
+            }
+            cur = numThreadsHoldingLock.incrementAndGet()
+            if (cur != 1) {
+              e = new AssertionError(s"numThreadsHoldingLock is not 1: $cur")
+            }
+            cur = numThreadsHoldingLock.decrementAndGet()
+            if (cur != 0) {
+              e = new AssertionError(s"numThreadsHoldingLock is not 0: $cur")
+            }
+            numFuncCalled.incrementAndGet()
+          }
+        } catch { case t: Throwable => e = t } // surface worker failures to the test thread
+      }
+    }
+    threads.foreach(_.start())
+    latch.countDown()
+    threads.foreach(_.join())
+    if (e != null) {
+      throw e
+    }
+    assert(numFuncCalled.get === numThreads)
+  }
+
+  test("A different key should not be locked") {
+    val keyLock = new KeyLock[Object]
+    val k1 = new Object
+    val k2 = new Object
+
+    // Start a thread to hold the lock for `k1` forever
+    val latch = new CountDownLatch(1)
+    val t = new Thread() {
+      override def run(): Unit = try {
+        keyLock.withLock(k1) {
+          latch.countDown()
+          Thread.sleep(foreverMs)
+        }
+      } catch {
+        case _: InterruptedException => // Ignore it as it's the exit signal
+      }
+    }
+    t.start()
+    try {
+      // Wait until the thread gets the lock for `k1`
+      if (!latch.await(foreverMs, TimeUnit.MILLISECONDS)) {
+        throw new TimeoutException("thread didn't get the lock")
+      }
+
+      var funcCalled = false
+      // Verify we can acquire the lock for `k2` and call `func`
+      failAfter(foreverMs.millis) {
+        keyLock.withLock(k2) {
+          funcCalled = true
+        }
+      }
+      assert(funcCalled, "func is not called")
+    } finally {
+      t.interrupt()
+      t.join()
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org