You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ho...@apache.org on 2020/07/30 18:59:19 UTC

[spark] branch master updated: [SPARK-32199][SPARK-32198] Reduce job failures during decommissioning

This is an automated email from the ASF dual-hosted git repository.

holden pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 366a178  [SPARK-32199][SPARK-32198] Reduce job failures during decommissioning
366a178 is described below

commit 366a1789333bac97643159857a206bcd773761a4
Author: Devesh Agrawal <de...@gmail.com>
AuthorDate: Thu Jul 30 11:58:11 2020 -0700

    [SPARK-32199][SPARK-32198] Reduce job failures during decommissioning
    
    ### What changes were proposed in this pull request?
    
    This PR reduces the likelihood of job failures during decommissioning. It
    fixes two holes in the current decommissioning framework:
    
    - (a) Loss of a decommissioned executor is not treated as a job failure:
    we know that a decommissioned executor will die soon, so its death is
    clearly not caused by the application.
    
    - (b) Shuffle files on a decommissioned host are cleared when the
    first fetch failure from that host is detected. Deciding when to clear
    the shuffle state is a bit tricky: ideally you want to clear it the
    millisecond before the shuffle service on the node dies (or the executor
    dies when there is no external shuffle service). Clearing too soon could
    waste some work, while clearing too late would lead to fetch failures.
    
      The approach here is to do this clearing when the very first fetch
    failure is observed on the decommissioned block manager, without waiting for
    other blocks to also signal a failure.
    
    ### Why are the changes needed?
    
    Without them, decommissioning many executors at once leads to job failures.
    
    ### Code overview
    
    The task scheduler tracks the executors that were decommissioned along with their
    `ExecutorDecommissionInfo`. This information is used (a) by the task scheduler when handling an `ExecutorProcessLost` error, and (b) by the `DAGScheduler` when handling a fetch failure.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added a new test suite, `DecommissionWorkerSuite`, to test the new behavior by exercising Master-Worker decommissioning. I chose to add a new suite since the setup logic was quite different from the existing `WorkerDecommissionSuite`. I am open to changing the name of the newly added test suite :-)
    
    ### Questions for reviewers
    - Should I add a feature flag to guard these two behaviors? They seem safe to me, since they should only be triggered by decommissioning, but you never know :-).
    
    Closes #29014 from agrawaldevesh/decom_harden.
    
    Authored-by: Devesh Agrawal <de...@gmail.com>
    Signed-off-by: Holden Karau <hk...@apple.com>
---
 .../org/apache/spark/scheduler/DAGScheduler.scala  |  19 +-
 .../spark/scheduler/ExecutorLossReason.scala       |   7 +-
 .../org/apache/spark/scheduler/TaskScheduler.scala |   5 +
 .../apache/spark/scheduler/TaskSchedulerImpl.scala |  37 +-
 .../apache/spark/scheduler/TaskSetManager.scala    |   1 +
 .../spark/deploy/DecommissionWorkerSuite.scala     | 424 +++++++++++++++++++++
 .../apache/spark/scheduler/DAGSchedulerSuite.scala |   4 +
 .../scheduler/ExternalClusterManagerSuite.scala    |   2 +
 .../spark/scheduler/TaskSchedulerImplSuite.scala   |  47 +++
 9 files changed, 539 insertions(+), 7 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 2503ae0..6b376cd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1821,10 +1821,19 @@ private[spark] class DAGScheduler(
 
           // TODO: mark the executor as failed only if there were lots of fetch failures on it
           if (bmAddress != null) {
-            val hostToUnregisterOutputs = if (env.blockManager.externalShuffleServiceEnabled &&
-              unRegisterOutputOnHostOnFetchFailure) {
-              // We had a fetch failure with the external shuffle service, so we
-              // assume all shuffle data on the node is bad.
+            val externalShuffleServiceEnabled = env.blockManager.externalShuffleServiceEnabled
+            val isHostDecommissioned = taskScheduler
+              .getExecutorDecommissionInfo(bmAddress.executorId)
+              .exists(_.isHostDecommissioned)
+
+            // Shuffle output of all executors on host `bmAddress.host` may be lost if:
+            // - External shuffle service is enabled, so we assume that all shuffle data on node is
+            //   bad.
+            // - Host is decommissioned, thus all executors on that host will die.
+            val shuffleOutputOfEntireHostLost = externalShuffleServiceEnabled ||
+              isHostDecommissioned
+            val hostToUnregisterOutputs = if (shuffleOutputOfEntireHostLost
+              && unRegisterOutputOnHostOnFetchFailure) {
               Some(bmAddress.host)
             } else {
               // Unregister shuffle data just for one executor (we don't have any
@@ -2339,7 +2348,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
 
     case ExecutorLost(execId, reason) =>
       val workerLost = reason match {
-        case ExecutorProcessLost(_, true) => true
+        case ExecutorProcessLost(_, true, _) => true
         case _ => false
       }
       dagScheduler.handleExecutorLost(execId, workerLost)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala
index 4141ed7..671deda 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala
@@ -54,9 +54,14 @@ private [spark] object LossReasonPending extends ExecutorLossReason("Pending los
 /**
  * @param _message human readable loss reason
  * @param workerLost whether the worker is confirmed lost too (i.e. including shuffle service)
+ * @param causedByApp whether the loss of the executor is the fault of the running app.
+ *                    (assumed true by default unless known explicitly otherwise)
  */
 private[spark]
-case class ExecutorProcessLost(_message: String = "Worker lost", workerLost: Boolean = false)
+case class ExecutorProcessLost(
+    _message: String = "Executor Process Lost",
+    workerLost: Boolean = false,
+    causedByApp: Boolean = true)
   extends ExecutorLossReason(_message)
 
 /**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index b29458c..1101d06 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -104,6 +104,11 @@ private[spark] trait TaskScheduler {
   def executorDecommission(executorId: String, decommissionInfo: ExecutorDecommissionInfo): Unit
 
   /**
+   * If an executor is decommissioned, return its corresponding decommission info
+   */
+  def getExecutorDecommissionInfo(executorId: String): Option[ExecutorDecommissionInfo]
+
+  /**
    * Process a lost executor
    */
   def executorLost(executorId: String, reason: ExecutorLossReason): Unit
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 510318a..b734d9f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -136,6 +136,8 @@ private[spark] class TaskSchedulerImpl(
   // IDs of the tasks running on each executor
   private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]]
 
+  private val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionInfo]
+
   def runningTasksByExecutors: Map[String, Int] = synchronized {
     executorIdToRunningTaskIds.toMap.mapValues(_.size).toMap
   }
@@ -939,12 +941,43 @@ private[spark] class TaskSchedulerImpl(
 
   override def executorDecommission(
       executorId: String, decommissionInfo: ExecutorDecommissionInfo): Unit = {
+    synchronized {
+      // Don't bother noting decommissioning for executors that we don't know about
+      if (executorIdToHost.contains(executorId)) {
+        // The scheduler can get multiple decommission updates from multiple sources,
+        // and some of those can have isHostDecommissioned false. We merge them such that
+        // if we heard isHostDecommissioned ever true, then we keep that one since it is
+        // most likely coming from the cluster manager and thus authoritative
+        val oldDecomInfo = executorsPendingDecommission.get(executorId)
+        if (oldDecomInfo.isEmpty || !oldDecomInfo.get.isHostDecommissioned) {
+          executorsPendingDecommission(executorId) = decommissionInfo
+        }
+      }
+    }
     rootPool.executorDecommission(executorId)
     backend.reviveOffers()
   }
 
-  override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {
+  override def getExecutorDecommissionInfo(executorId: String)
+    : Option[ExecutorDecommissionInfo] = synchronized {
+      executorsPendingDecommission.get(executorId)
+  }
+
+  override def executorLost(executorId: String, givenReason: ExecutorLossReason): Unit = {
     var failedExecutor: Option[String] = None
+    val reason = givenReason match {
+      // Handle executor process loss due to decommissioning
+      case ExecutorProcessLost(message, origWorkerLost, origCausedByApp) =>
+        val executorDecommissionInfo = getExecutorDecommissionInfo(executorId)
+        ExecutorProcessLost(
+          message,
+          // Also mark the worker lost if we know that the host was decommissioned
+          origWorkerLost || executorDecommissionInfo.exists(_.isHostDecommissioned),
+          // Executor loss is certainly not caused by app if we knew that this executor is being
+          // decommissioned
+          causedByApp = executorDecommissionInfo.isEmpty && origCausedByApp)
+      case e => e
+    }
 
     synchronized {
       if (executorIdToRunningTaskIds.contains(executorId)) {
@@ -1033,6 +1066,8 @@ private[spark] class TaskSchedulerImpl(
       }
     }
 
+    executorsPendingDecommission -= executorId
+
     if (reason != LossReasonPending) {
       executorIdToHost -= executorId
       rootPool.executorLost(executorId, host, reason)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 4b31ff0..d69f358 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -985,6 +985,7 @@ private[spark] class TaskSetManager(
       val exitCausedByApp: Boolean = reason match {
         case exited: ExecutorExited => exited.exitCausedByApp
         case ExecutorKilled => false
+        case ExecutorProcessLost(_, _, false) => false
         case _ => true
       }
       handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(info.executorId, exitCausedByApp,
diff --git a/core/src/test/scala/org/apache/spark/deploy/DecommissionWorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/DecommissionWorkerSuite.scala
new file mode 100644
index 0000000..ee9a6be
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/deploy/DecommissionWorkerSuite.scala
@@ -0,0 +1,424 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}
+import java.util.concurrent.atomic.AtomicBoolean
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.concurrent.duration._
+
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark._
+import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState, WorkerDecommission}
+import org.apache.spark.deploy.master.{ApplicationInfo, Master, WorkerInfo}
+import org.apache.spark.deploy.worker.Worker
+import org.apache.spark.internal.{config, Logging}
+import org.apache.spark.network.TransportContext
+import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.network.shuffle.ExternalBlockHandler
+import org.apache.spark.rpc.{RpcAddress, RpcEnv}
+import org.apache.spark.scheduler._
+import org.apache.spark.shuffle.FetchFailedException
+import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.Utils
+
+class DecommissionWorkerSuite
+  extends SparkFunSuite
+    with Logging
+    with LocalSparkContext
+    with BeforeAndAfterEach {
+
+  private var masterAndWorkerConf: SparkConf = null
+  private var masterAndWorkerSecurityManager: SecurityManager = null
+  private var masterRpcEnv: RpcEnv = null
+  private var master: Master = null
+  private var workerIdToRpcEnvs: mutable.HashMap[String, RpcEnv] = null
+  private var workers: mutable.ArrayBuffer[Worker] = null
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    masterAndWorkerConf = new SparkConf()
+      .set(config.Worker.WORKER_DECOMMISSION_ENABLED, true)
+    masterAndWorkerSecurityManager = new SecurityManager(masterAndWorkerConf)
+    masterRpcEnv = RpcEnv.create(
+      Master.SYSTEM_NAME,
+      "localhost",
+      0,
+      masterAndWorkerConf,
+      masterAndWorkerSecurityManager)
+    master = makeMaster()
+    workerIdToRpcEnvs = mutable.HashMap.empty
+    workers = mutable.ArrayBuffer.empty
+  }
+
+  override def afterEach(): Unit = {
+    try {
+      masterRpcEnv.shutdown()
+      workerIdToRpcEnvs.values.foreach(_.shutdown())
+      workerIdToRpcEnvs.clear()
+      master.stop()
+      workers.foreach(_.stop())
+      workers.clear()
+      masterRpcEnv = null
+    } finally {
+      super.afterEach()
+    }
+  }
+
+  test("decommission workers should not result in job failure") {
+    val maxTaskFailures = 2
+    val numTimesToKillWorkers = maxTaskFailures + 1
+    val numWorkers = numTimesToKillWorkers + 1
+    createWorkers(numWorkers)
+
+    // Here we will have a single task job and we will keep decommissioning (and killing) the
+    // worker running that task K times. Where K is more than the maxTaskFailures. Since the worker
+    // is notified of the decommissioning, the task failures can be ignored and not fail
+    // the job.
+
+    sc = createSparkContext(config.TASK_MAX_FAILURES.key -> maxTaskFailures.toString)
+    val executorIdToWorkerInfo = getExecutorToWorkerAssignments
+    val taskIdsKilled = new ConcurrentHashMap[Long, Boolean]
+    val listener = new RootStageAwareListener {
+      override def handleRootTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+        val taskInfo = taskStart.taskInfo
+        if (taskIdsKilled.size() < numTimesToKillWorkers) {
+          val workerInfo = executorIdToWorkerInfo(taskInfo.executorId)
+          decommissionWorkerOnMaster(workerInfo, "partition 0 must die")
+          killWorkerAfterTimeout(workerInfo, 1)
+          taskIdsKilled.put(taskInfo.taskId, true)
+        }
+      }
+    }
+    TestUtils.withListener(sc, listener) { _ =>
+      val jobResult = sc.parallelize(1 to 1, 1).map { _ =>
+        Thread.sleep(5 * 1000L); 1
+      }.count()
+      assert(jobResult === 1)
+    }
+    // single task job that gets to run numTimesToKillWorkers + 1 times.
+    assert(listener.getTasksFinished().size === numTimesToKillWorkers + 1)
+    listener.rootTasksStarted.asScala.foreach { taskInfo =>
+      assert(taskInfo.index == 0, s"Unknown task index ${taskInfo.index}")
+    }
+    listener.rootTasksEnded.asScala.foreach { taskInfo =>
+      assert(taskInfo.index === 0, s"Expected task index ${taskInfo.index} to be 0")
+      // If a task has been killed then it shouldn't be successful
+      val taskSuccessExpected = !taskIdsKilled.getOrDefault(taskInfo.taskId, false)
+      val taskSuccessActual = taskInfo.successful
+      assert(taskSuccessActual === taskSuccessExpected,
+        s"Expected task success $taskSuccessActual == $taskSuccessExpected")
+    }
+  }
+
+  test("decommission workers ensure that shuffle output is regenerated even with shuffle service") {
+    createWorkers(2)
+    val ss = new ExternalShuffleServiceHolder()
+
+    sc = createSparkContext(
+      config.Tests.TEST_NO_STAGE_RETRY.key -> "true",
+      config.SHUFFLE_MANAGER.key -> "sort",
+      config.SHUFFLE_SERVICE_ENABLED.key -> "true",
+      config.SHUFFLE_SERVICE_PORT.key -> ss.getPort.toString
+    )
+
+    // Here we will create a 2 stage job: The first stage will have two tasks and the second stage
+    // will have one task. The two tasks in the first stage will be long and short. We decommission
+    // and kill the worker after the short task is done. Eventually the driver should get the
+    // executor lost signal for the short task executor. This should trigger regenerating
+    // the shuffle output since we cleanly decommissioned the executor, despite running with an
+    // external shuffle service.
+    try {
+      val executorIdToWorkerInfo = getExecutorToWorkerAssignments
+      val workerForTask0Decommissioned = new AtomicBoolean(false)
+      // single task job
+      val listener = new RootStageAwareListener {
+        override def handleRootTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+          val taskInfo = taskEnd.taskInfo
+          if (taskInfo.index == 0) {
+            if (workerForTask0Decommissioned.compareAndSet(false, true)) {
+              val workerInfo = executorIdToWorkerInfo(taskInfo.executorId)
+              decommissionWorkerOnMaster(workerInfo, "Kill early done map worker")
+              killWorkerAfterTimeout(workerInfo, 0)
+              logInfo(s"Killed the node ${workerInfo.hostPort} that was running the early task")
+            }
+          }
+        }
+      }
+      TestUtils.withListener(sc, listener) { _ =>
+        val jobResult = sc.parallelize(1 to 2, 2).mapPartitionsWithIndex((pid, _) => {
+          val sleepTimeSeconds = if (pid == 0) 1 else 10
+          Thread.sleep(sleepTimeSeconds * 1000L)
+          List(1).iterator
+        }, preservesPartitioning = true).repartition(1).sum()
+        assert(jobResult === 2)
+      }
+      val tasksSeen = listener.getTasksFinished()
+      // 4 tasks: 2 from first stage, one retry due to decom, one more from the second stage.
+      assert(tasksSeen.size === 4, s"Expected 4 tasks but got $tasksSeen")
+      listener.rootTasksStarted.asScala.foreach { taskInfo =>
+        assert(taskInfo.index <= 1, s"Expected ${taskInfo.index} <= 1")
+        assert(taskInfo.successful, s"Task ${taskInfo.index} should be successful")
+      }
+      val tasksEnded = listener.rootTasksEnded.asScala
+      tasksEnded.filter(_.index != 0).foreach { taskInfo =>
+        assert(taskInfo.attemptNumber === 0, "2nd task should succeed on 1st attempt")
+      }
+      val firstTaskAttempts = tasksEnded.filter(_.index == 0)
+      assert(firstTaskAttempts.size > 1, s"Task 0 should have multiple attempts")
+    } finally {
+      ss.close()
+    }
+  }
+
+  test("decommission workers ensure that fetch failures lead to rerun") {
+    createWorkers(2)
+    sc = createSparkContext(
+      config.Tests.TEST_NO_STAGE_RETRY.key -> "false",
+      config.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE.key -> "true")
+    val executorIdToWorkerInfo = getExecutorToWorkerAssignments
+    val executorToDecom = executorIdToWorkerInfo.keysIterator.next
+
+    // The task code below cannot call executorIdToWorkerInfo, so we need to pre-compute
+    // the worker to decom to force it to be serialized into the task.
+    val workerToDecom = executorIdToWorkerInfo(executorToDecom)
+
+    // The setup of this job is similar to the one above: 2 stage job with first stage having
+    // long and short tasks. Except that we want the shuffle output to be regenerated on a
+    // fetch failure instead of an executor lost. Since it is hard to "trigger a fetch failure",
+    // we manually raise the FetchFailed exception when the 2nd stage's task runs and require that
+    // fetch failure to trigger a recomputation.
+    logInfo(s"Will try to decommission the task running on executor $executorToDecom")
+    val listener = new RootStageAwareListener {
+      override def handleRootTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+        val taskInfo = taskEnd.taskInfo
+        if (taskInfo.executorId == executorToDecom && taskInfo.attemptNumber == 0 &&
+          taskEnd.stageAttemptId == 0) {
+          decommissionWorkerOnMaster(workerToDecom,
+            "decommission worker after task on it is done")
+        }
+      }
+    }
+    TestUtils.withListener(sc, listener) { _ =>
+      val jobResult = sc.parallelize(1 to 2, 2).mapPartitionsWithIndex((_, _) => {
+        val executorId = SparkEnv.get.executorId
+        val sleepTimeSeconds = if (executorId == executorToDecom) 10 else 1
+        Thread.sleep(sleepTimeSeconds * 1000L)
+        List(1).iterator
+      }, preservesPartitioning = true)
+        .repartition(1).mapPartitions(iter => {
+        val context = TaskContext.get()
+        if (context.attemptNumber == 0 && context.stageAttemptNumber() == 0) {
+          // MapIndex is explicitly -1 to force the entire host to be decommissioned
+          // However, this will cause both the tasks in the preceding stage since the host here is
+          // "localhost" (shortcoming of this single-machine unit test in that all the workers
+          // are actually on the same host)
+          throw new FetchFailedException(BlockManagerId(executorToDecom,
+            workerToDecom.host, workerToDecom.port), 0, 0, -1, 0, "Forcing fetch failure")
+        }
+        val sumVal: List[Int] = List(iter.sum)
+        sumVal.iterator
+      }, preservesPartitioning = true)
+        .sum()
+      assert(jobResult === 2)
+    }
+    // 6 tasks: 2 from first stage, 2 rerun again from first stage, 2nd stage attempt 1 and 2.
+    val tasksSeen = listener.getTasksFinished()
+    assert(tasksSeen.size === 6, s"Expected 6 tasks but got $tasksSeen")
+  }
+
+  private abstract class RootStageAwareListener extends SparkListener {
+    private var rootStageId: Option[Int] = None
+    private val tasksFinished = new ConcurrentLinkedQueue[String]()
+    private val jobDone = new AtomicBoolean(false)
+    val rootTasksStarted = new ConcurrentLinkedQueue[TaskInfo]()
+    val rootTasksEnded = new ConcurrentLinkedQueue[TaskInfo]()
+
+    protected def isRootStageId(stageId: Int): Boolean =
+      (rootStageId.isDefined && rootStageId.get == stageId)
+
+    override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
+      if (stageSubmitted.stageInfo.parentIds.isEmpty && rootStageId.isEmpty) {
+        rootStageId = Some(stageSubmitted.stageInfo.stageId)
+      }
+    }
+
+    override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+      jobEnd.jobResult match {
+        case JobSucceeded => jobDone.set(true)
+      }
+    }
+
+    protected def handleRootTaskEnd(end: SparkListenerTaskEnd) = {}
+
+    protected def handleRootTaskStart(start: SparkListenerTaskStart) = {}
+
+    override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+      if (isRootStageId(taskStart.stageId)) {
+        rootTasksStarted.add(taskStart.taskInfo)
+        handleRootTaskStart(taskStart)
+      }
+    }
+
+    override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+      val taskSignature = s"${taskEnd.stageId}:${taskEnd.stageAttemptId}:" +
+        s"${taskEnd.taskInfo.index}:${taskEnd.taskInfo.attemptNumber}"
+      logInfo(s"Task End $taskSignature")
+      tasksFinished.add(taskSignature)
+      if (isRootStageId(taskEnd.stageId)) {
+        rootTasksEnded.add(taskEnd.taskInfo)
+        handleRootTaskEnd(taskEnd)
+      }
+    }
+
+    def getTasksFinished(): Seq[String] = {
+      assert(jobDone.get(), "Job isn't successfully done yet")
+      tasksFinished.asScala.toSeq
+    }
+  }
+
+  private def getExecutorToWorkerAssignments: Map[String, WorkerInfo] = {
+    val executorIdToWorkerInfo = mutable.HashMap[String, WorkerInfo]()
+    master.workers.foreach { wi =>
+      assert(wi.executors.size <= 1, "There should be at most one executor per worker")
+      // Cast the executorId to string since the TaskInfo.executorId is a string
+      wi.executors.values.foreach { e =>
+        val executorIdString = e.id.toString
+        val oldWorkerInfo = executorIdToWorkerInfo.put(executorIdString, wi)
+        assert(oldWorkerInfo.isEmpty,
+          s"Executor $executorIdString already present on another worker ${oldWorkerInfo}")
+      }
+    }
+    executorIdToWorkerInfo.toMap
+  }
+
+  private def makeMaster(): Master = {
+    val master = new Master(
+      masterRpcEnv,
+      masterRpcEnv.address,
+      0,
+      masterAndWorkerSecurityManager,
+      masterAndWorkerConf)
+    masterRpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master)
+    master
+  }
+
+  private def createWorkers(numWorkers: Int, cores: Int = 1, memory: Int = 1024): Unit = {
+    val workerRpcEnvs = (0 until numWorkers).map { i =>
+      RpcEnv.create(
+        Worker.SYSTEM_NAME + i,
+        "localhost",
+        0,
+        masterAndWorkerConf,
+        masterAndWorkerSecurityManager)
+    }
+    workers.clear()
+    val rpcAddressToRpcEnv: mutable.HashMap[RpcAddress, RpcEnv] = mutable.HashMap.empty
+    workerRpcEnvs.foreach { rpcEnv =>
+      val workDir = Utils.createTempDir(namePrefix = this.getClass.getSimpleName()).toString
+      val worker = new Worker(rpcEnv, 0, cores, memory, Array(masterRpcEnv.address),
+        Worker.ENDPOINT_NAME, workDir, masterAndWorkerConf, masterAndWorkerSecurityManager)
+      rpcEnv.setupEndpoint(Worker.ENDPOINT_NAME, worker)
+      workers.append(worker)
+      val oldRpcEnv = rpcAddressToRpcEnv.put(rpcEnv.address, rpcEnv)
+      logInfo(s"Created a worker at ${rpcEnv.address} with workdir $workDir")
+      assert(oldRpcEnv.isEmpty, s"Detected duplicate rpcEnv ${oldRpcEnv} for ${rpcEnv.address}")
+    }
+    workerIdToRpcEnvs.clear()
+    // Wait until all workers register with master successfully
+    eventually(timeout(1.minute), interval(1.seconds)) {
+      val workersOnMaster = getMasterState.workers
+      val numWorkersCurrently = workersOnMaster.length
+      logInfo(s"Waiting for $numWorkers workers to come up: So far $numWorkersCurrently")
+      assert(numWorkersCurrently === numWorkers)
+      workersOnMaster.foreach { workerInfo =>
+        val rpcAddress = RpcAddress(workerInfo.host, workerInfo.port)
+        val rpcEnv = rpcAddressToRpcEnv(rpcAddress)
+        assert(rpcEnv != null, s"Cannot find the worker for $rpcAddress")
+        val oldRpcEnv = workerIdToRpcEnvs.put(workerInfo.id, rpcEnv)
+        assert(oldRpcEnv.isEmpty, s"Detected duplicate rpcEnv ${oldRpcEnv} for worker " +
+          s"${workerInfo.id}")
+      }
+    }
+    logInfo(s"Created ${workers.size} workers")
+  }
+
+  private def getMasterState: MasterStateResponse = {
+    master.self.askSync[MasterStateResponse](RequestMasterState)
+  }
+
+  private def getApplications(): Seq[ApplicationInfo] = {
+    getMasterState.activeApps
+  }
+
+  def decommissionWorkerOnMaster(workerInfo: WorkerInfo, reason: String): Unit = {
+    logInfo(s"Trying to decommission worker ${workerInfo.id} for reason `$reason`")
+    master.self.send(WorkerDecommission(workerInfo.id, workerInfo.endpoint))
+  }
+
+  def killWorkerAfterTimeout(workerInfo: WorkerInfo, secondsToWait: Int): Unit = {
+    val env = workerIdToRpcEnvs(workerInfo.id)
+    Thread.sleep(secondsToWait * 1000L)
+    env.shutdown()
+    env.awaitTermination()
+  }
+
+  def createSparkContext(extraConfs: (String, String)*): SparkContext = {
+    val conf = new SparkConf()
+      .setMaster(masterRpcEnv.address.toSparkURL)
+      .setAppName("test")
+      .setAll(extraConfs)
+    sc = new SparkContext(conf)
+    val appId = sc.applicationId
+    eventually(timeout(1.minute), interval(1.seconds)) {
+      val apps = getApplications()
+      assert(apps.size === 1)
+      assert(apps.head.id === appId)
+      assert(apps.head.getExecutorLimit === Int.MaxValue)
+    }
+    sc
+  }
+
+  private class ExternalShuffleServiceHolder() {
+    // The external shuffle service can start with default configs and not get polluted by the
+    // other configs used in this test.
+    private val transportConf = SparkTransportConf.fromSparkConf(new SparkConf(),
+      "shuffle", numUsableCores = 2)
+    private val rpcHandler = new ExternalBlockHandler(transportConf, null)
+    private val transportContext = new TransportContext(transportConf, rpcHandler)
+    private val server = transportContext.createServer()
+
+    def getPort: Int = server.getPort
+
+    def close(): Unit = {
+      Utils.tryLogNonFatalError {
+        server.close()
+      }
+      Utils.tryLogNonFatalError {
+        rpcHandler.close()
+      }
+      Utils.tryLogNonFatalError {
+        transportContext.close()
+      }
+    }
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 45af0d0..c829006 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -178,6 +178,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     override def executorDecommission(
       executorId: String,
       decommissionInfo: ExecutorDecommissionInfo): Unit = {}
+    override def getExecutorDecommissionInfo(
+      executorId: String): Option[ExecutorDecommissionInfo] = None
   }
 
   /**
@@ -785,6 +787,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
       override def executorDecommission(
         executorId: String,
         decommissionInfo: ExecutorDecommissionInfo): Unit = {}
+      override def getExecutorDecommissionInfo(
+        executorId: String): Option[ExecutorDecommissionInfo] = None
     }
     val noKillScheduler = new DAGScheduler(
       sc,
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
index b2a5f77..07d8867 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
@@ -101,4 +101,6 @@ private class DummyTaskScheduler extends TaskScheduler {
   override def executorDecommission(
     executorId: String,
     decommissionInfo: ExecutorDecommissionInfo): Unit = {}
+  override def getExecutorDecommissionInfo(
+    executorId: String): Option[ExecutorDecommissionInfo] = None
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index 9ca3ce9..e583645 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -1802,6 +1802,53 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
     assert(2 == taskDescriptions.head.resources(GPU).addresses.size)
   }
 
+  private def setupSchedulerForDecommissionTests(): TaskSchedulerImpl = {
+    val taskScheduler = setupSchedulerWithMaster(
+      s"local[2]",
+      config.CPUS_PER_TASK.key -> 1.toString)
+    taskScheduler.submitTasks(FakeTask.createTaskSet(2))
+    val multiCoreWorkerOffers = IndexedSeq(WorkerOffer("executor0", "host0", 1),
+      WorkerOffer("executor1", "host1", 1))
+    val taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten
+    assert(taskDescriptions.map(_.executorId).sorted === Seq("executor0", "executor1"))
+    taskScheduler
+  }
+
+  test("scheduler should keep the decommission info where host was decommissioned") {
+    val scheduler = setupSchedulerForDecommissionTests()
+
+    scheduler.executorDecommission("executor0", ExecutorDecommissionInfo("0", false))
+    scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("1", true))
+    scheduler.executorDecommission("executor0", ExecutorDecommissionInfo("0 new", false))
+    scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("1 new", false))
+
+    assert(scheduler.getExecutorDecommissionInfo("executor0")
+      === Some(ExecutorDecommissionInfo("0 new", false)))
+    assert(scheduler.getExecutorDecommissionInfo("executor1")
+      === Some(ExecutorDecommissionInfo("1", true)))
+    assert(scheduler.getExecutorDecommissionInfo("executor2").isEmpty)
+  }
+
+  test("scheduler should ignore decommissioning of removed executors") {
+    val scheduler = setupSchedulerForDecommissionTests()
+
+    // executor 0 is decommissioned after loosing
+    assert(scheduler.getExecutorDecommissionInfo("executor0").isEmpty)
+    scheduler.executorLost("executor0", ExecutorExited(0, false, "normal"))
+    assert(scheduler.getExecutorDecommissionInfo("executor0").isEmpty)
+    scheduler.executorDecommission("executor0", ExecutorDecommissionInfo("", false))
+    assert(scheduler.getExecutorDecommissionInfo("executor0").isEmpty)
+
+    // executor 1 is decommissioned before loosing
+    assert(scheduler.getExecutorDecommissionInfo("executor1").isEmpty)
+    scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("", false))
+    assert(scheduler.getExecutorDecommissionInfo("executor1").isDefined)
+    scheduler.executorLost("executor1", ExecutorExited(0, false, "normal"))
+    assert(scheduler.getExecutorDecommissionInfo("executor1").isEmpty)
+    scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("", false))
+    assert(scheduler.getExecutorDecommissionInfo("executor1").isEmpty)
+  }
+
   /**
    * Used by tests to simulate a task failure. This calls the failure handler explicitly, to ensure
    * that all the state is updated when this method returns. Otherwise, there's no way to know when


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org