You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by an...@apache.org on 2015/07/02 22:59:58 UTC

spark git commit: [SPARK-7835] Refactor HeartbeatReceiverSuite for coverage + cleanup

Repository: spark
Updated Branches:
  refs/heads/master fcbcba66c -> cd2035507


[SPARK-7835] Refactor HeartbeatReceiverSuite for coverage + cleanup

The existing test suite contains a lot of duplicate code and does not cover the most fundamental feature of `HeartbeatReceiver`: expiring hosts that have not sent a heartbeat in a while.

This makes the clock used by `HeartbeatReceiver` injectable (tests pass a `ManualClock`), and makes the receiver respond normally to heartbeats only from registered executors. A few internal messages are moved to `receiveAndReply` so the tests are deterministic and no longer rely on flaky constructs like `eventually`.

Author: Andrew Or <an...@databricks.com>

Closes #7173 from andrewor14/heartbeat-receiver-tests and squashes the following commits:

4a903d6 [Andrew Or] Increase HeartbeatReceiverSuite coverage and clean up


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cd203550
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cd203550
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cd203550

Branch: refs/heads/master
Commit: cd2035507891a7f426f6f45902d3b5f4fdbe88cf
Parents: fcbcba6
Author: Andrew Or <an...@databricks.com>
Authored: Thu Jul 2 13:59:56 2015 -0700
Committer: Andrew Or <an...@databricks.com>
Committed: Thu Jul 2 13:59:56 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/HeartbeatReceiver.scala    |  89 +++++++---
 .../scala/org/apache/spark/SparkContext.scala   |   2 +-
 .../apache/spark/HeartbeatReceiverSuite.scala   | 161 ++++++++++++++-----
 3 files changed, 191 insertions(+), 61 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/cd203550/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
index 6909015..221b1da 100644
--- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
+++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
@@ -24,8 +24,8 @@ import scala.collection.mutable
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext}
 import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.scheduler.{SlaveLost, TaskScheduler}
-import org.apache.spark.util.{ThreadUtils, Utils}
+import org.apache.spark.scheduler._
+import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils}
 
 /**
  * A heartbeat from executors to the driver. This is a shared message used by several internal
@@ -45,13 +45,23 @@ private[spark] case object TaskSchedulerIsSet
 
 private[spark] case object ExpireDeadHosts
 
+private case class ExecutorRegistered(executorId: String)
+
+private case class ExecutorRemoved(executorId: String)
+
 private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean)
 
 /**
  * Lives in the driver to receive heartbeats from executors..
  */
-private[spark] class HeartbeatReceiver(sc: SparkContext)
-  extends ThreadSafeRpcEndpoint with Logging {
+private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock)
+  extends ThreadSafeRpcEndpoint with SparkListener with Logging {
+
+  def this(sc: SparkContext) {
+    this(sc, new SystemClock)
+  }
+
+  sc.addSparkListener(this)
 
   override val rpcEnv: RpcEnv = sc.env.rpcEnv
 
@@ -86,30 +96,48 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
   override def onStart(): Unit = {
     timeoutCheckingTask = eventLoopThread.scheduleAtFixedRate(new Runnable {
       override def run(): Unit = Utils.tryLogNonFatalError {
-        Option(self).foreach(_.send(ExpireDeadHosts))
+        Option(self).foreach(_.ask[Boolean](ExpireDeadHosts))
       }
     }, 0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS)
   }
 
-  override def receive: PartialFunction[Any, Unit] = {
-    case ExpireDeadHosts =>
-      expireDeadHosts()
+  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+
+    // Messages sent and received locally
+    case ExecutorRegistered(executorId) =>
+      executorLastSeen(executorId) = clock.getTimeMillis()
+      context.reply(true)
+    case ExecutorRemoved(executorId) =>
+      executorLastSeen.remove(executorId)
+      context.reply(true)
     case TaskSchedulerIsSet =>
       scheduler = sc.taskScheduler
-  }
+      context.reply(true)
+    case ExpireDeadHosts =>
+      expireDeadHosts()
+      context.reply(true)
 
-  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+    // Messages received from executors
     case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) =>
       if (scheduler != null) {
-        executorLastSeen(executorId) = System.currentTimeMillis()
-        eventLoopThread.submit(new Runnable {
-          override def run(): Unit = Utils.tryLogNonFatalError {
-            val unknownExecutor = !scheduler.executorHeartbeatReceived(
-              executorId, taskMetrics, blockManagerId)
-            val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor)
-            context.reply(response)
-          }
-        })
+        if (executorLastSeen.contains(executorId)) {
+          executorLastSeen(executorId) = clock.getTimeMillis()
+          eventLoopThread.submit(new Runnable {
+            override def run(): Unit = Utils.tryLogNonFatalError {
+              val unknownExecutor = !scheduler.executorHeartbeatReceived(
+                executorId, taskMetrics, blockManagerId)
+              val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor)
+              context.reply(response)
+            }
+          })
+        } else {
+          // This may happen if we get an executor's in-flight heartbeat immediately
+          // after we just removed it. It's not really an error condition so we should
+          // not log warning here. Otherwise there may be a lot of noise especially if
+          // we explicitly remove executors (SPARK-4134).
+          logDebug(s"Received heartbeat from unknown executor $executorId")
+          context.reply(HeartbeatResponse(reregisterBlockManager = true))
+        }
       } else {
         // Because Executor will sleep several seconds before sending the first "Heartbeat", this
         // case rarely happens. However, if it really happens, log it and ask the executor to
@@ -119,9 +147,30 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
       }
   }
 
+  /**
+   * If the heartbeat receiver is not stopped, notify it of executor registrations.
+   */
+  override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = {
+    Option(self).foreach(_.ask[Boolean](ExecutorRegistered(executorAdded.executorId)))
+  }
+
+  /**
+   * If the heartbeat receiver is not stopped, notify it of executor removals so it doesn't
+   * log superfluous errors.
+   *
+   * Note that we must do this after the executor is actually removed to guard against the
+   * following race condition: if we remove an executor's metadata from our data structure
+   * prematurely, we may get an in-flight heartbeat from the executor before the executor is
+   * actually removed, in which case we will still mark the executor as a dead host later
+   * and expire it with loud error messages.
+   */
+  override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = {
+    Option(self).foreach(_.ask[Boolean](ExecutorRemoved(executorRemoved.executorId)))
+  }
+
   private def expireDeadHosts(): Unit = {
     logTrace("Checking for hosts with no recent heartbeats in HeartbeatReceiver.")
-    val now = System.currentTimeMillis()
+    val now = clock.getTimeMillis()
     for ((executorId, lastSeenMs) <- executorLastSeen) {
       if (now - lastSeenMs > executorTimeoutMs) {
         logWarning(s"Removing executor $executorId with no recent heartbeats: " +

http://git-wip-us.apache.org/repos/asf/spark/blob/cd203550/core/src/main/scala/org/apache/spark/SparkContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 8eed467..d2547ee 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -498,7 +498,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
     _schedulerBackend = sched
     _taskScheduler = ts
     _dagScheduler = new DAGScheduler(this)
-    _heartbeatReceiver.send(TaskSchedulerIsSet)
+    _heartbeatReceiver.ask[Boolean](TaskSchedulerIsSet)
 
     // start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's
     // constructor

http://git-wip-us.apache.org/repos/asf/spark/blob/cd203550/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
index 911b3bd..b31b091 100644
--- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
@@ -17,64 +17,145 @@
 
 package org.apache.spark
 
-import scala.concurrent.duration._
 import scala.language.postfixOps
 
-import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.storage.BlockManagerId
+import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester}
 import org.mockito.Mockito.{mock, spy, verify, when}
 import org.mockito.Matchers
 import org.mockito.Matchers._
 
-import org.apache.spark.scheduler.TaskScheduler
-import org.apache.spark.util.RpcUtils
-import org.scalatest.concurrent.Eventually._
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.scheduler._
+import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.ManualClock
 
-class HeartbeatReceiverSuite extends SparkFunSuite with LocalSparkContext {
+class HeartbeatReceiverSuite
+  extends SparkFunSuite
+  with BeforeAndAfterEach
+  with PrivateMethodTester
+  with LocalSparkContext {
 
-  test("HeartbeatReceiver") {
+  private val executorId1 = "executor-1"
+  private val executorId2 = "executor-2"
+
+  // Shared state that must be reset before and after each test
+  private var scheduler: TaskScheduler = null
+  private var heartbeatReceiver: HeartbeatReceiver = null
+  private var heartbeatReceiverRef: RpcEndpointRef = null
+  private var heartbeatReceiverClock: ManualClock = null
+
+  override def beforeEach(): Unit = {
     sc = spy(new SparkContext("local[2]", "test"))
-    val scheduler = mock(classOf[TaskScheduler])
-    when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true)
+    scheduler = mock(classOf[TaskScheduler])
     when(sc.taskScheduler).thenReturn(scheduler)
+    heartbeatReceiverClock = new ManualClock
+    heartbeatReceiver = new HeartbeatReceiver(sc, heartbeatReceiverClock)
+    heartbeatReceiverRef = sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver)
+    when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true)
+  }
 
-    val heartbeatReceiver = new HeartbeatReceiver(sc)
-    sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet)
-    eventually(timeout(5 seconds), interval(5 millis)) {
-      assert(heartbeatReceiver.scheduler != null)
-    }
-    val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv)
+  override def afterEach(): Unit = {
+    resetSparkContext()
+    scheduler = null
+    heartbeatReceiver = null
+    heartbeatReceiverRef = null
+    heartbeatReceiverClock = null
+  }
 
-    val metrics = new TaskMetrics
-    val blockManagerId = BlockManagerId("executor-1", "localhost", 12345)
-    val response = receiverRef.askWithRetry[HeartbeatResponse](
-      Heartbeat("executor-1", Array(1L -> metrics), blockManagerId))
+  test("task scheduler is set correctly") {
+    assert(heartbeatReceiver.scheduler === null)
+    heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
+    assert(heartbeatReceiver.scheduler !== null)
+  }
 
-    verify(scheduler).executorHeartbeatReceived(
-      Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId))
-    assert(false === response.reregisterBlockManager)
+  test("normal heartbeat") {
+    heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
+    heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null))
+    heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null))
+    triggerHeartbeat(executorId1, executorShouldReregister = false)
+    triggerHeartbeat(executorId2, executorShouldReregister = false)
+    val trackedExecutors = executorLastSeen(heartbeatReceiver)
+    assert(trackedExecutors.size === 2)
+    assert(trackedExecutors.contains(executorId1))
+    assert(trackedExecutors.contains(executorId2))
   }
 
-  test("HeartbeatReceiver re-register") {
-    sc = spy(new SparkContext("local[2]", "test"))
-    val scheduler = mock(classOf[TaskScheduler])
-    when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(false)
-    when(sc.taskScheduler).thenReturn(scheduler)
+  test("reregister if scheduler is not ready yet") {
+    heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null))
+    // Task scheduler not set in HeartbeatReceiver
+    triggerHeartbeat(executorId1, executorShouldReregister = true)
+  }
 
-    val heartbeatReceiver = new HeartbeatReceiver(sc)
-    sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet)
-    eventually(timeout(5 seconds), interval(5 millis)) {
-      assert(heartbeatReceiver.scheduler != null)
-    }
-    val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv)
+  test("reregister if heartbeat from unregistered executor") {
+    heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
+    // Received heartbeat from unknown receiver, so we ask it to re-register
+    triggerHeartbeat(executorId1, executorShouldReregister = true)
+    assert(executorLastSeen(heartbeatReceiver).isEmpty)
+  }
+
+  test("reregister if heartbeat from removed executor") {
+    heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
+    heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null))
+    heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null))
+    // Remove the second executor but not the first
+    heartbeatReceiver.onExecutorRemoved(SparkListenerExecutorRemoved(0, executorId2, "bad boy"))
+    // Now trigger the heartbeats
+    // A heartbeat from the second executor should require reregistering
+    triggerHeartbeat(executorId1, executorShouldReregister = false)
+    triggerHeartbeat(executorId2, executorShouldReregister = true)
+    val trackedExecutors = executorLastSeen(heartbeatReceiver)
+    assert(trackedExecutors.size === 1)
+    assert(trackedExecutors.contains(executorId1))
+    assert(!trackedExecutors.contains(executorId2))
+  }
 
+  test("expire dead hosts") {
+    val executorTimeout = executorTimeoutMs(heartbeatReceiver)
+    heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
+    heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null))
+    heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null))
+    triggerHeartbeat(executorId1, executorShouldReregister = false)
+    triggerHeartbeat(executorId2, executorShouldReregister = false)
+    // Advance the clock and only trigger a heartbeat for the first executor
+    heartbeatReceiverClock.advance(executorTimeout / 2)
+    triggerHeartbeat(executorId1, executorShouldReregister = false)
+    heartbeatReceiverClock.advance(executorTimeout)
+    heartbeatReceiverRef.askWithRetry[Boolean](ExpireDeadHosts)
+    // Only the second executor should be expired as a dead host
+    verify(scheduler).executorLost(Matchers.eq(executorId2), any())
+    val trackedExecutors = executorLastSeen(heartbeatReceiver)
+    assert(trackedExecutors.size === 1)
+    assert(trackedExecutors.contains(executorId1))
+    assert(!trackedExecutors.contains(executorId2))
+  }
+
+  /** Manually send a heartbeat and return the response. */
+  private def triggerHeartbeat(
+      executorId: String,
+      executorShouldReregister: Boolean): Unit = {
     val metrics = new TaskMetrics
-    val blockManagerId = BlockManagerId("executor-1", "localhost", 12345)
-    val response = receiverRef.askWithRetry[HeartbeatResponse](
-      Heartbeat("executor-1", Array(1L -> metrics), blockManagerId))
+    val blockManagerId = BlockManagerId(executorId, "localhost", 12345)
+    val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](
+      Heartbeat(executorId, Array(1L -> metrics), blockManagerId))
+    if (executorShouldReregister) {
+      assert(response.reregisterBlockManager)
+    } else {
+      assert(!response.reregisterBlockManager)
+      // Additionally verify that the scheduler callback is called with the correct parameters
+      verify(scheduler).executorHeartbeatReceived(
+        Matchers.eq(executorId), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId))
+    }
+  }
 
-    verify(scheduler).executorHeartbeatReceived(
-      Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId))
-    assert(true === response.reregisterBlockManager)
+  // Helper methods to access private fields in HeartbeatReceiver
+  private val _executorLastSeen = PrivateMethod[collection.Map[String, Long]]('executorLastSeen)
+  private val _executorTimeoutMs = PrivateMethod[Long]('executorTimeoutMs)
+  private def executorLastSeen(receiver: HeartbeatReceiver): collection.Map[String, Long] = {
+    receiver invokePrivate _executorLastSeen()
+  }
+  private def executorTimeoutMs(receiver: HeartbeatReceiver): Long = {
+    receiver invokePrivate _executorTimeoutMs()
   }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org