You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ho...@apache.org on 2020/08/26 22:17:38 UTC

[spark] branch master updated: [SPARK-32643][CORE][K8S] Consolidate state decommissioning in the TaskSchedulerImpl realm

This is an automated email from the ASF dual-hosted git repository.

holden pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b786f31  [SPARK-32643][CORE][K8S] Consolidate state decommissioning in the TaskSchedulerImpl realm
b786f31 is described below

commit b786f31a42180523b0baa8113e26b2ddee445498
Author: Devesh Agrawal <de...@gmail.com>
AuthorDate: Wed Aug 26 15:16:47 2020 -0700

    [SPARK-32643][CORE][K8S] Consolidate state decommissioning in the TaskSchedulerImpl realm
    
    ### What changes were proposed in this pull request?
    The decommissioning state is a bit fragmented across two places in the TaskSchedulerImpl:
    
    https://github.com/apache/spark/pull/29014/ stored the incoming decommission info messages in TaskSchedulerImpl.executorsPendingDecommission.
    Meanwhile, https://github.com/apache/spark/pull/28619/ stored just the executor end time in the map TaskSetManager.tidToExecutorKillTimeMapping (which in turn is contained in TaskSchedulerImpl).
    While the two states are not really overlapping, it's a bit of a code hygiene concern to save this state in two places.
    
    With https://github.com/apache/spark/pull/29422, TaskSchedulerImpl is emerging as the place where all decommissioning book keeping is kept within the driver. So consolidate the information in _tidToExecutorKillTimeMapping_ into _executorsPendingDecommission_.
    
    However, in order to do so, we need to move away from keeping the raw ExecutorDecommissionInfo messages and instead keep a separate class, ExecutorDecommissionState. This decoupling will allow the RPC message class ExecutorDecommissionInfo to evolve independently from the bookkeeping class ExecutorDecommissionState.
    
    ### Why are the changes needed?
    
    This is just a code cleanup. These two features were added independently, and it's time to consolidate their state for good hygiene.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Existing tests.
    
    Closes #29452 from agrawaldevesh/consolidate_decom_state.
    
    Authored-by: Devesh Agrawal <de...@gmail.com>
    Signed-off-by: Holden Karau <hk...@apple.com>
---
 .../apache/spark/ExecutorAllocationClient.scala    |   2 +-
 .../org/apache/spark/scheduler/DAGScheduler.scala  |   2 +-
 .../spark/scheduler/ExecutorDecommissionInfo.scala |  14 ++-
 .../org/apache/spark/scheduler/TaskScheduler.scala |   2 +-
 .../apache/spark/scheduler/TaskSchedulerImpl.scala |  52 ++++++----
 .../apache/spark/scheduler/TaskSetManager.scala    |  34 +++----
 .../apache/spark/scheduler/DAGSchedulerSuite.scala |   8 +-
 .../scheduler/ExternalClusterManagerSuite.scala    |   4 +-
 .../spark/scheduler/TaskSchedulerImplSuite.scala   | 112 ++++++++++++++-------
 .../spark/scheduler/TaskSetManagerSuite.scala      |  33 +++---
 10 files changed, 167 insertions(+), 96 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
index 079340a..ce47f3f 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
@@ -88,7 +88,7 @@ private[spark] trait ExecutorAllocationClient {
    * Default implementation delegates to kill, scheduler must override
    * if it supports graceful decommissioning.
    *
-   * @param executorsAndDecominfo identifiers of executors & decom info.
+   * @param executorsAndDecomInfo identifiers of executors & decom info.
    * @param adjustTargetNumExecutors whether the target number of executors will be adjusted down
    *                                 after these executors have been decommissioned.
    * @return the ids of the executors acknowledged by the cluster manager to be removed.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index ae0387e..18cd241 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1825,7 +1825,7 @@ private[spark] class DAGScheduler(
           if (bmAddress != null) {
             val externalShuffleServiceEnabled = env.blockManager.externalShuffleServiceEnabled
             val isHostDecommissioned = taskScheduler
-              .getExecutorDecommissionInfo(bmAddress.executorId)
+              .getExecutorDecommissionState(bmAddress.executorId)
               .exists(_.isHostDecommissioned)
 
             // Shuffle output of all executors on host `bmAddress.host` may be lost if:
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorDecommissionInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorDecommissionInfo.scala
index a82b5d3..48ae879 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorDecommissionInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorDecommissionInfo.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.scheduler
 
 /**
- * Provides more detail when an executor is being decommissioned.
+ * Message providing more detail when an executor is being decommissioned.
  * @param message Human readable reason for why the decommissioning is happening.
  * @param isHostDecommissioned Whether the host (aka the `node` or `worker` in other places) is
  *                             being decommissioned too. Used to infer if the shuffle data might
@@ -26,3 +26,15 @@ package org.apache.spark.scheduler
  */
 private[spark]
 case class ExecutorDecommissionInfo(message: String, isHostDecommissioned: Boolean)
+
+/**
+ * State related to decommissioning that is kept by the TaskSchedulerImpl. This state is derived
+ * from the info message above but it is kept distinct to allow the state to evolve independently
+ * from the message.
+ */
+case class ExecutorDecommissionState(
+    // Timestamp the decommissioning commenced as per the Driver's clock,
+    // to estimate when the executor might eventually be lost if EXECUTOR_DECOMMISSION_KILL_INTERVAL
+    // is configured.
+    startTime: Long,
+    isHostDecommissioned: Boolean)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 1101d06..0fa80bb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -106,7 +106,7 @@ private[spark] trait TaskScheduler {
   /**
    * If an executor is decommissioned, return its corresponding decommission info
    */
-  def getExecutorDecommissionInfo(executorId: String): Option[ExecutorDecommissionInfo]
+  def getExecutorDecommissionState(executorId: String): Option[ExecutorDecommissionState]
 
   /**
    * Process a lost executor
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index db6797c..d446638 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -141,7 +141,7 @@ private[spark] class TaskSchedulerImpl(
 
   // We add executors here when we first get decommission notification for them. Executors can
   // continue to run even after being asked to decommission, but they will eventually exit.
-  val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionInfo]
+  val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionState]
 
   // When they exit and we know of that via heartbeat failure, we will add them to this cache.
   // This cache is consulted to know if a fetch failure is because a source executor was
@@ -152,7 +152,7 @@ private[spark] class TaskSchedulerImpl(
     .ticker(new Ticker{
       override def read(): Long = TimeUnit.MILLISECONDS.toNanos(clock.getTimeMillis())
     })
-    .build[String, ExecutorDecommissionInfo]()
+    .build[String, ExecutorDecommissionState]()
     .asMap()
 
   def runningTasksByExecutors: Map[String, Int] = synchronized {
@@ -293,7 +293,7 @@ private[spark] class TaskSchedulerImpl(
   private[scheduler] def createTaskSetManager(
       taskSet: TaskSet,
       maxTaskFailures: Int): TaskSetManager = {
-    new TaskSetManager(this, taskSet, maxTaskFailures, blacklistTrackerOpt)
+    new TaskSetManager(this, taskSet, maxTaskFailures, blacklistTrackerOpt, clock)
   }
 
   override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
@@ -922,22 +922,36 @@ private[spark] class TaskSchedulerImpl(
     synchronized {
       // Don't bother noting decommissioning for executors that we don't know about
       if (executorIdToHost.contains(executorId)) {
-        // The scheduler can get multiple decommission updates from multiple sources,
-        // and some of those can have isHostDecommissioned false. We merge them such that
-        // if we heard isHostDecommissioned ever true, then we keep that one since it is
-        // most likely coming from the cluster manager and thus authoritative
-        val oldDecomInfo = executorsPendingDecommission.get(executorId)
-        if (!oldDecomInfo.exists(_.isHostDecommissioned)) {
-          executorsPendingDecommission(executorId) = decommissionInfo
+        val oldDecomStateOpt = executorsPendingDecommission.get(executorId)
+        val newDecomState = if (oldDecomStateOpt.isEmpty) {
+          // This is the first time we are hearing of decommissioning this executor,
+          // so create a brand new state.
+          ExecutorDecommissionState(
+            clock.getTimeMillis(),
+            decommissionInfo.isHostDecommissioned)
+        } else {
+          val oldDecomState = oldDecomStateOpt.get
+          if (!oldDecomState.isHostDecommissioned && decommissionInfo.isHostDecommissioned) {
+            // Only the cluster manager is allowed to send decommission messages with
+            // isHostDecommissioned set. So the new decommissionInfo is from the cluster
+            // manager and is thus authoritative. Flip isHostDecommissioned to true but keep the old
+            // decommission start time.
+            ExecutorDecommissionState(
+              oldDecomState.startTime,
+              isHostDecommissioned = true)
+          } else {
+            oldDecomState
+          }
         }
+        executorsPendingDecommission(executorId) = newDecomState
       }
     }
     rootPool.executorDecommission(executorId)
     backend.reviveOffers()
   }
 
-  override def getExecutorDecommissionInfo(executorId: String)
-    : Option[ExecutorDecommissionInfo] = synchronized {
+  override def getExecutorDecommissionState(executorId: String)
+    : Option[ExecutorDecommissionState] = synchronized {
     executorsPendingDecommission
       .get(executorId)
       .orElse(Option(decommissionedExecutorsRemoved.get(executorId)))
@@ -948,14 +962,14 @@ private[spark] class TaskSchedulerImpl(
     val reason = givenReason match {
       // Handle executor process loss due to decommissioning
       case ExecutorProcessLost(message, origWorkerLost, origCausedByApp) =>
-        val executorDecommissionInfo = getExecutorDecommissionInfo(executorId)
+        val executorDecommissionState = getExecutorDecommissionState(executorId)
         ExecutorProcessLost(
           message,
           // Also mark the worker lost if we know that the host was decommissioned
-          origWorkerLost || executorDecommissionInfo.exists(_.isHostDecommissioned),
+          origWorkerLost || executorDecommissionState.exists(_.isHostDecommissioned),
           // Executor loss is certainly not caused by app if we knew that this executor is being
           // decommissioned
-          causedByApp = executorDecommissionInfo.isEmpty && origCausedByApp)
+          causedByApp = executorDecommissionState.isEmpty && origCausedByApp)
       case e => e
     }
 
@@ -1047,8 +1061,8 @@ private[spark] class TaskSchedulerImpl(
     }
 
 
-    val decomInfo = executorsPendingDecommission.remove(executorId)
-    decomInfo.foreach(decommissionedExecutorsRemoved.put(executorId, _))
+    val decomState = executorsPendingDecommission.remove(executorId)
+    decomState.foreach(decommissionedExecutorsRemoved.put(executorId, _))
 
     if (reason != LossReasonPending) {
       executorIdToHost -= executorId
@@ -1085,12 +1099,12 @@ private[spark] class TaskSchedulerImpl(
 
   // exposed for test
   protected final def isExecutorDecommissioned(execId: String): Boolean =
-    getExecutorDecommissionInfo(execId).nonEmpty
+    getExecutorDecommissionState(execId).isDefined
 
   // exposed for test
   protected final def isHostDecommissioned(host: String): Boolean = {
     hostToExecutors.get(host).exists { executors =>
-      executors.exists(e => getExecutorDecommissionInfo(e).exists(_.isHostDecommissioned))
+      executors.exists(e => getExecutorDecommissionState(e).exists(_.isHostDecommissioned))
     }
   }
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 3a779d1..ff03876 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -169,7 +169,6 @@ private[spark] class TaskSetManager(
 
   // Task index, start and finish time for each task attempt (indexed by task ID)
   private[scheduler] val taskInfos = new HashMap[Long, TaskInfo]
-  private[scheduler] val tidToExecutorKillTimeMapping = new HashMap[Long, Long]
 
   // Use a MedianHeap to record durations of successful tasks so we know when to launch
   // speculative tasks. This is only used when speculation is enabled, to avoid the overhead
@@ -943,7 +942,6 @@ private[spark] class TaskSetManager(
 
   /** If the given task ID is in the set of running tasks, removes it. */
   def removeRunningTask(tid: Long): Unit = {
-    tidToExecutorKillTimeMapping.remove(tid)
     if (runningTasksSet.remove(tid) && parent != null) {
       parent.decreaseRunningTasks(1)
     }
@@ -1054,15 +1052,21 @@ private[spark] class TaskSetManager(
       logDebug("Task length threshold for speculation: " + threshold)
       for (tid <- runningTasksSet) {
         var speculated = checkAndSubmitSpeculatableTask(tid, time, threshold)
-        if (!speculated && tidToExecutorKillTimeMapping.contains(tid)) {
-          // Check whether this task will finish before the exectorKillTime assuming
-          // it will take medianDuration overall. If this task cannot finish within
-          // executorKillInterval, then this task is a candidate for speculation
-          val taskEndTimeBasedOnMedianDuration = taskInfos(tid).launchTime + medianDuration
-          val canExceedDeadline = tidToExecutorKillTimeMapping(tid) <
-            taskEndTimeBasedOnMedianDuration
-          if (canExceedDeadline) {
-            speculated = checkAndSubmitSpeculatableTask(tid, time, 0)
+        if (!speculated && executorDecommissionKillInterval.isDefined) {
+          val taskInfo = taskInfos(tid)
+          val decomState = sched.getExecutorDecommissionState(taskInfo.executorId)
+          if (decomState.isDefined) {
+            // Check if this task might finish after this executor is decommissioned.
+            // We estimate the task's finish time by using the median task duration.
+            // Whereas the time when the executor might be decommissioned is estimated using the
+            // config executorDecommissionKillInterval. If the task is going to finish after
+            // decommissioning, then we will eagerly speculate the task.
+            val taskEndTimeBasedOnMedianDuration = taskInfos(tid).launchTime + medianDuration
+            val executorDecomTime = decomState.get.startTime + executorDecommissionKillInterval.get
+            val canExceedDeadline = executorDecomTime < taskEndTimeBasedOnMedianDuration
+            if (canExceedDeadline) {
+              speculated = checkAndSubmitSpeculatableTask(tid, time, 0)
+            }
           }
         }
         foundTasks |= speculated
@@ -1123,14 +1127,6 @@ private[spark] class TaskSetManager(
 
   def executorDecommission(execId: String): Unit = {
     recomputeLocality()
-    if (speculationEnabled) {
-      executorDecommissionKillInterval.foreach { interval =>
-        val executorKillTime = clock.getTimeMillis() + interval
-        runningTasksSet.filter(taskInfos(_).executorId == execId).foreach { tid =>
-          tidToExecutorKillTimeMapping(tid) = executorKillTime
-        }
-      }
-    }
   }
 
   def recomputeLocality(): Unit = {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index c829006..a7f8aff 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -178,8 +178,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     override def executorDecommission(
       executorId: String,
       decommissionInfo: ExecutorDecommissionInfo): Unit = {}
-    override def getExecutorDecommissionInfo(
-      executorId: String): Option[ExecutorDecommissionInfo] = None
+    override def getExecutorDecommissionState(
+      executorId: String): Option[ExecutorDecommissionState] = None
   }
 
   /**
@@ -787,8 +787,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
       override def executorDecommission(
         executorId: String,
         decommissionInfo: ExecutorDecommissionInfo): Unit = {}
-      override def getExecutorDecommissionInfo(
-        executorId: String): Option[ExecutorDecommissionInfo] = None
+      override def getExecutorDecommissionState(
+        executorId: String): Option[ExecutorDecommissionState] = None
     }
     val noKillScheduler = new DAGScheduler(
       sc,
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
index 07d8867..08191d0 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
@@ -101,6 +101,6 @@ private class DummyTaskScheduler extends TaskScheduler {
   override def executorDecommission(
     executorId: String,
     decommissionInfo: ExecutorDecommissionInfo): Unit = {}
-  override def getExecutorDecommissionInfo(
-    executorId: String): Option[ExecutorDecommissionInfo] = None
+  override def getExecutorDecommissionState(
+    executorId: String): Option[ExecutorDecommissionState] = None
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index 1c8d799..26c9d91 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -88,15 +88,10 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
   }
 
   def setupSchedulerWithMaster(master: String, confs: (String, String)*): TaskSchedulerImpl = {
-    setupSchedulerWithMasterAndClock(master, new SystemClock, confs: _*)
-  }
-
-  def setupSchedulerWithMasterAndClock(master: String, clock: Clock, confs: (String, String)*):
-  TaskSchedulerImpl = {
     val conf = new SparkConf().setMaster(master).setAppName("TaskSchedulerImplSuite")
     confs.foreach { case (k, v) => conf.set(k, v) }
     sc = new SparkContext(conf)
-    taskScheduler = new TaskSchedulerImpl(sc, sc.conf.get(config.TASK_MAX_FAILURES), clock = clock)
+    taskScheduler = new TaskSchedulerImpl(sc, sc.conf.get(config.TASK_MAX_FAILURES))
     setupHelper()
   }
 
@@ -1834,66 +1829,111 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
     assert(2 == taskDescriptions.head.resources(GPU).addresses.size)
   }
 
-  private def setupSchedulerForDecommissionTests(clock: Clock): TaskSchedulerImpl = {
-    val taskScheduler = setupSchedulerWithMasterAndClock(
-      s"local[2]",
-      clock,
-      config.CPUS_PER_TASK.key -> 1.toString)
-    taskScheduler.submitTasks(FakeTask.createTaskSet(2))
-    val multiCoreWorkerOffers = IndexedSeq(WorkerOffer("executor0", "host0", 1),
-      WorkerOffer("executor1", "host1", 1))
-    val taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten
-    assert(taskDescriptions.map(_.executorId).sorted === Seq("executor0", "executor1"))
+  private def setupSchedulerForDecommissionTests(clock: Clock, numTasks: Int): TaskSchedulerImpl = {
+    // one task per host
+    val numHosts = numTasks
+    val conf = new SparkConf()
+      .setMaster(s"local[$numHosts]")
+      .setAppName("TaskSchedulerImplSuite")
+      .set(config.CPUS_PER_TASK.key, "1")
+    sc = new SparkContext(conf)
+    val maxTaskFailures = sc.conf.get(config.TASK_MAX_FAILURES)
+    taskScheduler = new TaskSchedulerImpl(sc, maxTaskFailures, clock = clock) {
+      override def createTaskSetManager(taskSet: TaskSet, maxFailures: Int): TaskSetManager = {
+        val tsm = super.createTaskSetManager(taskSet, maxFailures)
+        // we need to create a spied tsm so that we can see the copies running
+        val tsmSpy = spy(tsm)
+        stageToMockTaskSetManager(taskSet.stageId) = tsmSpy
+        tsmSpy
+      }
+    }
+    setupHelper()
+    // Spawn the tasks on different executors/hosts
+    taskScheduler.submitTasks(FakeTask.createTaskSet(numTasks))
+    for (i <- 0 until numTasks) {
+      val executorId = s"executor$i"
+      val taskDescriptions = taskScheduler.resourceOffers(IndexedSeq(WorkerOffer(
+         executorId, s"host$i", 1))).flatten
+      assert(taskDescriptions.size === 1)
+      assert(taskDescriptions(0).executorId == executorId)
+      assert(taskDescriptions(0).index === i)
+    }
     taskScheduler
   }
 
-  test("scheduler should keep the decommission info where host was decommissioned") {
-    val scheduler = setupSchedulerForDecommissionTests(new SystemClock)
-
+  test("scheduler should keep the decommission state where host was decommissioned") {
+    val clock = new ManualClock(10000L)
+    val scheduler = setupSchedulerForDecommissionTests(clock, 2)
+    val oldTime = clock.getTimeMillis()
     scheduler.executorDecommission("executor0", ExecutorDecommissionInfo("0", false))
     scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("1", true))
+
+    clock.advance(3000L)
     scheduler.executorDecommission("executor0", ExecutorDecommissionInfo("0 new", false))
     scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("1 new", false))
 
-    assert(scheduler.getExecutorDecommissionInfo("executor0")
-      === Some(ExecutorDecommissionInfo("0 new", false)))
-    assert(scheduler.getExecutorDecommissionInfo("executor1")
-      === Some(ExecutorDecommissionInfo("1", true)))
-    assert(scheduler.getExecutorDecommissionInfo("executor2").isEmpty)
+    assert(scheduler.getExecutorDecommissionState("executor0")
+      === Some(ExecutorDecommissionState(oldTime, false)))
+    assert(scheduler.getExecutorDecommissionState("executor1")
+      === Some(ExecutorDecommissionState(oldTime, true)))
+    assert(scheduler.getExecutorDecommissionState("executor2").isEmpty)
   }
 
-  test("scheduler should eventually purge removed and decommissioned executors") {
+  test("test full decommissioning flow") {
     val clock = new ManualClock(10000L)
-    val scheduler = setupSchedulerForDecommissionTests(clock)
+    val scheduler = setupSchedulerForDecommissionTests(clock, 2)
+    val manager = stageToMockTaskSetManager(0)
+    // The task started should be running.
+    assert(manager.copiesRunning.take(2) === Array(1, 1))
 
     // executor 0 is decommissioned after loosing
-    assert(scheduler.getExecutorDecommissionInfo("executor0").isEmpty)
+    assert(scheduler.getExecutorDecommissionState("executor0").isEmpty)
     scheduler.executorLost("executor0", ExecutorExited(0, false, "normal"))
-    assert(scheduler.getExecutorDecommissionInfo("executor0").isEmpty)
+    assert(scheduler.getExecutorDecommissionState("executor0").isEmpty)
     scheduler.executorDecommission("executor0", ExecutorDecommissionInfo("", false))
-    assert(scheduler.getExecutorDecommissionInfo("executor0").isEmpty)
+    assert(scheduler.getExecutorDecommissionState("executor0").isEmpty)
+
+    // 0th task just died above
+    assert(manager.copiesRunning.take(2) === Array(0, 1))
 
     assert(scheduler.executorsPendingDecommission.isEmpty)
     clock.advance(5000)
 
+    // executor1 hasn't been decommissioned yet
+    assert(scheduler.getExecutorDecommissionState("executor1").isEmpty)
+
     // executor 1 is decommissioned before loosing
-    assert(scheduler.getExecutorDecommissionInfo("executor1").isEmpty)
     scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("", false))
-    assert(scheduler.getExecutorDecommissionInfo("executor1").isDefined)
+    assert(scheduler.getExecutorDecommissionState("executor1").isDefined)
     clock.advance(2000)
+
+    // executor1 is eventually lost
     scheduler.executorLost("executor1", ExecutorExited(0, false, "normal"))
     assert(scheduler.decommissionedExecutorsRemoved.size === 1)
     assert(scheduler.executorsPendingDecommission.isEmpty)
+    // So now both the tasks are no longer running
+    assert(manager.copiesRunning.take(2) === Array(0, 0))
     clock.advance(2000)
-    // It hasn't been 60 seconds yet before removal
-    assert(scheduler.getExecutorDecommissionInfo("executor1").isDefined)
+
+    // Decommission state should hang around a bit after removal ...
+    assert(scheduler.getExecutorDecommissionState("executor1").isDefined)
     scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("", false))
     clock.advance(2000)
     assert(scheduler.decommissionedExecutorsRemoved.size === 1)
-    assert(scheduler.getExecutorDecommissionInfo("executor1").isDefined)
-    clock.advance(301000)
-    assert(scheduler.getExecutorDecommissionInfo("executor1").isEmpty)
+    assert(scheduler.getExecutorDecommissionState("executor1").isDefined)
+
+    // The default timeout for expiry is 300k milliseconds (5 minutes) which completes now,
+    // and the executor1's decommission state should finally be purged.
+    clock.advance(300000)
+    assert(scheduler.getExecutorDecommissionState("executor1").isEmpty)
     assert(scheduler.decommissionedExecutorsRemoved.isEmpty)
+
+    // Now give it some resources and both tasks should be rerun
+    val taskDescriptions = taskScheduler.resourceOffers(IndexedSeq(
+      WorkerOffer("executor2", "host2", 1), WorkerOffer("executor3", "host3", 1))).flatten
+    assert(taskDescriptions.size === 2)
+    assert(taskDescriptions.map(_.index).sorted == Seq(0, 1))
+    assert(manager.copiesRunning.take(2) === Array(1, 1))
   }
 
   /**
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index c6f8fa5..86d4e92 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -41,7 +41,7 @@ import org.apache.spark.resource.TestResourceIDs._
 import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
 import org.apache.spark.serializer.SerializerInstance
 import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.util.{AccumulatorV2, ManualClock}
+import org.apache.spark.util.{AccumulatorV2, Clock, ManualClock, SystemClock}
 
 class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
   extends DAGScheduler(sc) {
@@ -109,8 +109,11 @@ object FakeRackUtil {
  * a list of "live" executors and their hostnames for isExecutorAlive and hasExecutorsAliveOnHost
  * to work, and these are required for locality in TaskSetManager.
  */
-class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */)
-  extends TaskSchedulerImpl(sc)
+class FakeTaskScheduler(
+    sc: SparkContext,
+    clock: Clock,
+    liveExecutors: (String, String)* /* execId, host */)
+  extends TaskSchedulerImpl(sc, sc.conf.get(config.TASK_MAX_FAILURES), clock = clock)
 {
   val startedTasks = new ArrayBuffer[Long]
   val endedTasks = new mutable.HashMap[Long, TaskEndReason]
@@ -120,6 +123,10 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
 
   val executors = new mutable.HashMap[String, String]
 
+  def this(sc: SparkContext, liveExecutors: (String, String)*) = {
+    this(sc, new SystemClock, liveExecutors: _*)
+  }
+
   // this must be initialized before addExecutor
   override val defaultRackValue: Option[String] = Some("default")
   for ((execId, host) <- liveExecutors) {
@@ -1974,14 +1981,16 @@ class TaskSetManagerSuite
   test("SPARK-21040: Check speculative tasks are launched when an executor is decommissioned" +
     " and the tasks running on it cannot finish within EXECUTOR_DECOMMISSION_KILL_INTERVAL") {
     sc = new SparkContext("local", "test")
-    sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3"))
+    val clock = new ManualClock()
+    sched = new FakeTaskScheduler(sc, clock,
+      ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3"))
+    sched.backend = mock(classOf[SchedulerBackend])
     val taskSet = FakeTask.createTaskSet(4)
     sc.conf.set(config.SPECULATION_ENABLED, true)
     sc.conf.set(config.SPECULATION_MULTIPLIER, 1.5)
     sc.conf.set(config.SPECULATION_QUANTILE, 0.5)
     sc.conf.set(config.EXECUTOR_DECOMMISSION_KILL_INTERVAL.key, "5s")
-    val clock = new ManualClock()
-    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock)
+    val manager = sched.createTaskSetManager(taskSet, MAX_TASK_FAILURES)
     val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task =>
       task.metrics.internalAccums
     }
@@ -2017,13 +2026,13 @@ class TaskSetManagerSuite
     assert(!manager.checkSpeculatableTasks(0))
     assert(sched.speculativeTasks.toSet === Set())
 
-    // decommission exec-2. All tasks running on exec-2 (i.e. TASK 2,3)  will be added to
-    // executorDecommissionSpeculationTriggerTimeoutOpt
+    // decommission exec-2. All tasks running on exec-2 (i.e. TASK 2,3) will be now
+    // checked if they should be speculated.
     // (TASK 2 -> 15, TASK 3 -> 15)
-    manager.executorDecommission("exec2")
-    assert(manager.tidToExecutorKillTimeMapping.keySet === Set(2, 3))
-    assert(manager.tidToExecutorKillTimeMapping(2) === 15*1000)
-    assert(manager.tidToExecutorKillTimeMapping(3) === 15*1000)
+    sched.executorDecommission("exec2", ExecutorDecommissionInfo("decom",
+      isHostDecommissioned = false))
+    assert(sched.getExecutorDecommissionState("exec2").map(_.startTime) ===
+      Some(clock.getTimeMillis()))
 
     assert(manager.checkSpeculatableTasks(0))
     // TASK 2 started at t=0s, so it can still finish before t=15s (Median task runtime = 10s)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org