You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ka...@apache.org on 2015/07/20 19:29:04 UTC

spark git commit: [SPARK-8103][core] DAGScheduler should not submit multiple concurrent attempts for a stage

Repository: spark
Updated Branches:
  refs/heads/master c6fe9b4a1 -> 80e2568b2


[SPARK-8103][core] DAGScheduler should not submit multiple concurrent attempts for a stage

https://issues.apache.org/jira/browse/SPARK-8103

cc kayousterhout (thanks for the extra test case)

Author: Imran Rashid <ir...@cloudera.com>
Author: Kay Ousterhout <ka...@gmail.com>
Author: Imran Rashid <sq...@users.noreply.github.com>

Closes #6750 from squito/SPARK-8103 and squashes the following commits:

fb3acfc [Imran Rashid] fix log msg
e01b7aa [Imran Rashid] fix some comments, style
584acd4 [Imran Rashid] simplify going from taskId to taskSetMgr
e43ac25 [Imran Rashid] Merge branch 'master' into SPARK-8103
6bc23af [Imran Rashid] update log msg
4470fa1 [Imran Rashid] rename
c04707e [Imran Rashid] style
88b61cc [Imran Rashid] add tests to make sure that TaskSchedulerImpl schedules correctly with zombie attempts
d7f1ef2 [Imran Rashid] get rid of activeTaskSets
a21c8b5 [Imran Rashid] Merge branch 'master' into SPARK-8103
906d626 [Imran Rashid] fix merge
109900e [Imran Rashid] Merge branch 'master' into SPARK-8103
c0d4d90 [Imran Rashid] Revert "Index active task sets by stage Id rather than by task set id"
f025154 [Imran Rashid] Merge pull request #2 from kayousterhout/imran_SPARK-8103
baf46e1 [Kay Ousterhout] Index active task sets by stage Id rather than by task set id
19685bb [Imran Rashid] switch to using latestInfo.attemptId, and add comments
a5f7c8c [Imran Rashid] remove comment for reviewers
227b40d [Imran Rashid] style
517b6e5 [Imran Rashid] get rid of SparkIllegalStateException
b2faef5 [Imran Rashid] faster check for conflicting task sets
6542b42 [Imran Rashid] remove extra stageAttemptId
ada7726 [Imran Rashid] reviewer feedback
d8eb202 [Imran Rashid] Merge branch 'master' into SPARK-8103
46bc26a [Imran Rashid] more cleanup of debug garbage
cb245da [Imran Rashid] finally found the issue ... clean up debug stuff
8c29707 [Imran Rashid] Merge branch 'master' into SPARK-8103
89a59b6 [Imran Rashid] more printlns ...
9601b47 [Imran Rashid] more debug printlns
ecb4e7d [Imran Rashid] debugging printlns
b6bc248 [Imran Rashid] style
55f4a94 [Imran Rashid] get rid of more random test case since kays tests are clearer
7021d28 [Imran Rashid] update test since listenerBus.waitUntilEmpty now throws an exception instead of returning a boolean
883fe49 [Kay Ousterhout] Unit tests for concurrent stages issue
6e14683 [Imran Rashid] unit test just to make sure we fail fast on concurrent attempts
06a0af6 [Imran Rashid] ignore for jenkins
c443def [Imran Rashid] better fix and simpler test case
28d70aa [Imran Rashid] wip on getting a better test case ...
a9bf31f [Imran Rashid] wip


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/80e2568b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/80e2568b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/80e2568b

Branch: refs/heads/master
Commit: 80e2568b25780a7094199239da8ad6cfb6efc9f7
Parents: c6fe9b4
Author: Imran Rashid <ir...@cloudera.com>
Authored: Mon Jul 20 10:28:32 2015 -0700
Committer: Kay Ousterhout <ka...@gmail.com>
Committed: Mon Jul 20 10:28:32 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/scheduler/DAGScheduler.scala   |  78 +++++-----
 .../org/apache/spark/scheduler/ResultTask.scala |   3 +-
 .../apache/spark/scheduler/ShuffleMapTask.scala |   5 +-
 .../scala/org/apache/spark/scheduler/Task.scala |   5 +-
 .../spark/scheduler/TaskSchedulerImpl.scala     |  99 ++++++++-----
 .../org/apache/spark/scheduler/TaskSet.scala    |   4 +-
 .../cluster/CoarseGrainedSchedulerBackend.scala |   5 +-
 .../spark/scheduler/DAGSchedulerSuite.scala     | 141 +++++++++++++++++++
 .../org/apache/spark/scheduler/FakeTask.scala   |   8 +-
 .../scheduler/NotSerializableFakeTask.scala     |   2 +-
 .../spark/scheduler/TaskContextSuite.scala      |   4 +-
 .../scheduler/TaskSchedulerImplSuite.scala      | 113 ++++++++++++++-
 .../spark/scheduler/TaskSetManagerSuite.scala   |   2 +-
 13 files changed, 383 insertions(+), 86 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/80e2568b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index dd55cd8..71a219a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -857,7 +857,6 @@ class DAGScheduler(
     // Get our pending tasks and remember them in our pendingTasks entry
     stage.pendingTasks.clear()
 
-
     // First figure out the indexes of partition ids to compute.
     val partitionsToCompute: Seq[Int] = {
       stage match {
@@ -918,7 +917,7 @@ class DAGScheduler(
           partitionsToCompute.map { id =>
             val locs = getPreferredLocs(stage.rdd, id)
             val part = stage.rdd.partitions(id)
-            new ShuffleMapTask(stage.id, taskBinary, part, locs)
+            new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs)
           }
 
         case stage: ResultStage =>
@@ -927,7 +926,7 @@ class DAGScheduler(
             val p: Int = job.partitions(id)
             val part = stage.rdd.partitions(p)
             val locs = getPreferredLocs(stage.rdd, p)
-            new ResultTask(stage.id, taskBinary, part, locs, id)
+            new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id)
           }
       }
     } catch {
@@ -1069,10 +1068,11 @@ class DAGScheduler(
             val execId = status.location.executorId
             logDebug("ShuffleMapTask finished on " + execId)
             if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
-              logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
+              logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
             } else {
               shuffleStage.addOutputLoc(smt.partitionId, status)
             }
+
             if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) {
               markStageAsFinished(shuffleStage)
               logInfo("looking for newly runnable stages")
@@ -1132,38 +1132,48 @@ class DAGScheduler(
         val failedStage = stageIdToStage(task.stageId)
         val mapStage = shuffleToMapStage(shuffleId)
 
-        // It is likely that we receive multiple FetchFailed for a single stage (because we have
-        // multiple tasks running concurrently on different executors). In that case, it is possible
-        // the fetch failure has already been handled by the scheduler.
-        if (runningStages.contains(failedStage)) {
-          logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
-            s"due to a fetch failure from $mapStage (${mapStage.name})")
-          markStageAsFinished(failedStage, Some(failureMessage))
-        }
+        if (failedStage.latestInfo.attemptId != task.stageAttemptId) {
+          logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
+            s" ${task.stageAttemptId} and there is a more recent attempt for that stage " +
+            s"(attempt ID ${failedStage.latestInfo.attemptId}) running")
+        } else {
 
-        if (disallowStageRetryForTest) {
-          abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
-        } else if (failedStages.isEmpty) {
-          // Don't schedule an event to resubmit failed stages if failed isn't empty, because
-          // in that case the event will already have been scheduled.
-          // TODO: Cancel running tasks in the stage
-          logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
-            s"$failedStage (${failedStage.name}) due to fetch failure")
-          messageScheduler.schedule(new Runnable {
-            override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
-          }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
-        }
-        failedStages += failedStage
-        failedStages += mapStage
-        // Mark the map whose fetch failed as broken in the map stage
-        if (mapId != -1) {
-          mapStage.removeOutputLoc(mapId, bmAddress)
-          mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
-        }
+          // It is likely that we receive multiple FetchFailed for a single stage (because we have
+          // multiple tasks running concurrently on different executors). In that case, it is
+          // possible the fetch failure has already been handled by the scheduler.
+          if (runningStages.contains(failedStage)) {
+            logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
+              s"due to a fetch failure from $mapStage (${mapStage.name})")
+            markStageAsFinished(failedStage, Some(failureMessage))
+          } else {
+            logDebug(s"Received fetch failure from $task, but its from $failedStage which is no " +
+              s"longer running")
+          }
+
+          if (disallowStageRetryForTest) {
+            abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
+          } else if (failedStages.isEmpty) {
+            // Don't schedule an event to resubmit failed stages if failed isn't empty, because
+            // in that case the event will already have been scheduled.
+            // TODO: Cancel running tasks in the stage
+            logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
+              s"$failedStage (${failedStage.name}) due to fetch failure")
+            messageScheduler.schedule(new Runnable {
+              override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+            }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
+          }
+          failedStages += failedStage
+          failedStages += mapStage
+          // Mark the map whose fetch failed as broken in the map stage
+          if (mapId != -1) {
+            mapStage.removeOutputLoc(mapId, bmAddress)
+            mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
+          }
 
-        // TODO: mark the executor as failed only if there were lots of fetch failures on it
-        if (bmAddress != null) {
-          handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
+          // TODO: mark the executor as failed only if there were lots of fetch failures on it
+          if (bmAddress != null) {
+            handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
+          }
         }
 
       case commitDenied: TaskCommitDenied =>

http://git-wip-us.apache.org/repos/asf/spark/blob/80e2568b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index c9a1241..9c2606e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -41,11 +41,12 @@ import org.apache.spark.rdd.RDD
  */
 private[spark] class ResultTask[T, U](
     stageId: Int,
+    stageAttemptId: Int,
     taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
     @transient locs: Seq[TaskLocation],
     val outputId: Int)
-  extends Task[U](stageId, partition.index) with Serializable {
+  extends Task[U](stageId, stageAttemptId, partition.index) with Serializable {
 
   @transient private[this] val preferredLocs: Seq[TaskLocation] = {
     if (locs == null) Nil else locs.toSet.toSeq

http://git-wip-us.apache.org/repos/asf/spark/blob/80e2568b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index bd3dd23..14c8c00 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -40,14 +40,15 @@ import org.apache.spark.shuffle.ShuffleWriter
  */
 private[spark] class ShuffleMapTask(
     stageId: Int,
+    stageAttemptId: Int,
     taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
     @transient private var locs: Seq[TaskLocation])
-  extends Task[MapStatus](stageId, partition.index) with Logging {
+  extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging {
 
   /** A constructor used only in test suites. This does not require passing in an RDD. */
   def this(partitionId: Int) {
-    this(0, null, new Partition { override def index: Int = 0 }, null)
+    this(0, 0, null, new Partition { override def index: Int = 0 }, null)
   }
 
   @transient private val preferredLocs: Seq[TaskLocation] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/80e2568b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 6a86f9d..76a19ae 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -43,7 +43,10 @@ import org.apache.spark.util.Utils
  * @param stageId id of the stage this task belongs to
  * @param partitionId index of the number in the RDD
  */
-private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
+private[spark] abstract class Task[T](
+    val stageId: Int,
+    val stageAttemptId: Int,
+    var partitionId: Int) extends Serializable {
 
   /**
    * The key of the Map is the accumulator id and the value of the Map is the latest accumulator

http://git-wip-us.apache.org/repos/asf/spark/blob/80e2568b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index ed3dde0..1705e7f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -75,9 +75,9 @@ private[spark] class TaskSchedulerImpl(
 
   // TaskSetManagers are not thread safe, so any access to one should be synchronized
   // on this class.
-  val activeTaskSets = new HashMap[String, TaskSetManager]
+  private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]
 
-  val taskIdToTaskSetId = new HashMap[Long, String]
+  private[scheduler] val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager]
   val taskIdToExecutorId = new HashMap[Long, String]
 
   @volatile private var hasReceivedTask = false
@@ -162,7 +162,17 @@ private[spark] class TaskSchedulerImpl(
     logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
     this.synchronized {
       val manager = createTaskSetManager(taskSet, maxTaskFailures)
-      activeTaskSets(taskSet.id) = manager
+      val stage = taskSet.stageId
+      val stageTaskSets =
+        taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
+      stageTaskSets(taskSet.stageAttemptId) = manager
+      val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
+        ts.taskSet != taskSet && !ts.isZombie
+      }
+      if (conflictingTaskSet) {
+        throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
+          s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
+      }
       schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
 
       if (!isLocal && !hasReceivedTask) {
@@ -192,19 +202,21 @@ private[spark] class TaskSchedulerImpl(
 
   override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
     logInfo("Cancelling stage " + stageId)
-    activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
-      // There are two possible cases here:
-      // 1. The task set manager has been created and some tasks have been scheduled.
-      //    In this case, send a kill signal to the executors to kill the task and then abort
-      //    the stage.
-      // 2. The task set manager has been created but no tasks has been scheduled. In this case,
-      //    simply abort the stage.
-      tsm.runningTasksSet.foreach { tid =>
-        val execId = taskIdToExecutorId(tid)
-        backend.killTask(tid, execId, interruptThread)
+    taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
+      attempts.foreach { case (_, tsm) =>
+        // There are two possible cases here:
+        // 1. The task set manager has been created and some tasks have been scheduled.
+        //    In this case, send a kill signal to the executors to kill the task and then abort
+        //    the stage.
+        // 2. The task set manager has been created but no tasks have been scheduled. In this case,
+        //    simply abort the stage.
+        tsm.runningTasksSet.foreach { tid =>
+          val execId = taskIdToExecutorId(tid)
+          backend.killTask(tid, execId, interruptThread)
+        }
+        tsm.abort("Stage %s cancelled".format(stageId))
+        logInfo("Stage %d was cancelled".format(stageId))
       }
-      tsm.abort("Stage %s cancelled".format(stageId))
-      logInfo("Stage %d was cancelled".format(stageId))
     }
   }
 
@@ -214,7 +226,12 @@ private[spark] class TaskSchedulerImpl(
    * cleaned up.
    */
   def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
-    activeTaskSets -= manager.taskSet.id
+    taskSetsByStageIdAndAttempt.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
+      taskSetsForStage -= manager.taskSet.stageAttemptId
+      if (taskSetsForStage.isEmpty) {
+        taskSetsByStageIdAndAttempt -= manager.taskSet.stageId
+      }
+    }
     manager.parent.removeSchedulable(manager)
     logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s"
       .format(manager.taskSet.id, manager.parent.name))
@@ -235,7 +252,7 @@ private[spark] class TaskSchedulerImpl(
           for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
             tasks(i) += task
             val tid = task.taskId
-            taskIdToTaskSetId(tid) = taskSet.taskSet.id
+            taskIdToTaskSetManager(tid) = taskSet
             taskIdToExecutorId(tid) = execId
             executorsByHost(host) += execId
             availableCpus(i) -= CPUS_PER_TASK
@@ -319,26 +336,24 @@ private[spark] class TaskSchedulerImpl(
             failedExecutor = Some(execId)
           }
         }
-        taskIdToTaskSetId.get(tid) match {
-          case Some(taskSetId) =>
+        taskIdToTaskSetManager.get(tid) match {
+          case Some(taskSet) =>
             if (TaskState.isFinished(state)) {
-              taskIdToTaskSetId.remove(tid)
+              taskIdToTaskSetManager.remove(tid)
               taskIdToExecutorId.remove(tid)
             }
-            activeTaskSets.get(taskSetId).foreach { taskSet =>
-              if (state == TaskState.FINISHED) {
-                taskSet.removeRunningTask(tid)
-                taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
-              } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
-                taskSet.removeRunningTask(tid)
-                taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
-              }
+            if (state == TaskState.FINISHED) {
+              taskSet.removeRunningTask(tid)
+              taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
+            } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
+              taskSet.removeRunningTask(tid)
+              taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
             }
           case None =>
             logError(
               ("Ignoring update with state %s for TID %s because its task set is gone (this is " +
-               "likely the result of receiving duplicate task finished status updates)")
-              .format(state, tid))
+                "likely the result of receiving duplicate task finished status updates)")
+                .format(state, tid))
         }
       } catch {
         case e: Exception => logError("Exception in statusUpdate", e)
@@ -363,9 +378,9 @@ private[spark] class TaskSchedulerImpl(
 
     val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
       taskMetrics.flatMap { case (id, metrics) =>
-        taskIdToTaskSetId.get(id)
-          .flatMap(activeTaskSets.get)
-          .map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
+        taskIdToTaskSetManager.get(id).map { taskSetMgr =>
+          (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
+        }
       }
     }
     dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId)
@@ -397,9 +412,12 @@ private[spark] class TaskSchedulerImpl(
 
   def error(message: String) {
     synchronized {
-      if (activeTaskSets.nonEmpty) {
+      if (taskSetsByStageIdAndAttempt.nonEmpty) {
         // Have each task set throw a SparkException with the error
-        for ((taskSetId, manager) <- activeTaskSets) {
+        for {
+          attempts <- taskSetsByStageIdAndAttempt.values
+          manager <- attempts.values
+        } {
           try {
             manager.abort(message)
           } catch {
@@ -520,6 +538,17 @@ private[spark] class TaskSchedulerImpl(
 
   override def applicationAttemptId(): Option[String] = backend.applicationAttemptId()
 
+  private[scheduler] def taskSetManagerForAttempt(
+      stageId: Int,
+      stageAttemptId: Int): Option[TaskSetManager] = {
+    for {
+      attempts <- taskSetsByStageIdAndAttempt.get(stageId)
+      manager <- attempts.get(stageAttemptId)
+    } yield {
+      manager
+    }
+  }
+
 }
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/80e2568b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
index c3ad325..be8526b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
@@ -26,10 +26,10 @@ import java.util.Properties
 private[spark] class TaskSet(
     val tasks: Array[Task[_]],
     val stageId: Int,
-    val attempt: Int,
+    val stageAttemptId: Int,
     val priority: Int,
     val properties: Properties) {
-    val id: String = stageId + "." + attempt
+    val id: String = stageId + "." + stageAttemptId
 
   override def toString: String = "TaskSet " + id
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/80e2568b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 0e3215d..f14c603 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -191,15 +191,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
       for (task <- tasks.flatten) {
         val serializedTask = ser.serialize(task)
         if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
-          val taskSetId = scheduler.taskIdToTaskSetId(task.taskId)
-          scheduler.activeTaskSets.get(taskSetId).foreach { taskSet =>
+          scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr =>
             try {
               var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
                 "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
                 "spark.akka.frameSize or using broadcast variables for large values."
               msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize,
                 AkkaUtils.reservedSizeBytes)
-              taskSet.abort(msg)
+              taskSetMgr.abort(msg)
             } catch {
               case e: Exception => logError("Exception in error callback", e)
             }

http://git-wip-us.apache.org/repos/asf/spark/blob/80e2568b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 4f2b0fa..86728cb 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -101,9 +101,15 @@ class DAGSchedulerSuite
   /** Length of time to wait while draining listener events. */
   val WAIT_TIMEOUT_MILLIS = 10000
   val sparkListener = new SparkListener() {
+    val submittedStageInfos = new HashSet[StageInfo]
     val successfulStages = new HashSet[Int]
     val failedStages = new ArrayBuffer[Int]
     val stageByOrderOfExecution = new ArrayBuffer[Int]
+
+    override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
+      submittedStageInfos += stageSubmitted.stageInfo
+    }
+
     override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
       val stageInfo = stageCompleted.stageInfo
       stageByOrderOfExecution += stageInfo.stageId
@@ -150,6 +156,7 @@ class DAGSchedulerSuite
     // Enable local execution for this test
     val conf = new SparkConf().set("spark.localExecution.enabled", "true")
     sc = new SparkContext("local", "DAGSchedulerSuite", conf)
+    sparkListener.submittedStageInfos.clear()
     sparkListener.successfulStages.clear()
     sparkListener.failedStages.clear()
     failure = null
@@ -547,6 +554,140 @@ class DAGSchedulerSuite
     assert(sparkListener.failedStages.size == 1)
   }
 
+  /**
+   * This tests the case where another FetchFailed comes in while the map stage is getting
+   * re-run.
+   */
+  test("late fetch failures don't cause multiple concurrent attempts for the same map stage") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+    submit(reduceRdd, Array(0, 1))
+
+    val mapStageId = 0
+    def countSubmittedMapStageAttempts(): Int = {
+      sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
+    }
+
+    // The map stage should have been submitted.
+    assert(countSubmittedMapStageAttempts() === 1)
+
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", 2)),
+      (Success, makeMapStatus("hostB", 2))))
+    // The MapOutputTracker should know about both map output locations.
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) ===
+      Array("hostA", "hostB"))
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 1).map(_._1.host) ===
+      Array("hostA", "hostB"))
+
+    // The first result task fails, with a fetch failure for the output from the first mapper.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assert(sparkListener.failedStages.contains(1))
+
+    // Trigger resubmission of the failed map stage.
+    runEvent(ResubmitFailedStages)
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+
+    // Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
+    assert(countSubmittedMapStageAttempts() === 2)
+
+    // The second ResultTask fails, with a fetch failure for the output from the second mapper.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(1),
+      FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+
+    // Another ResubmitFailedStages event should not result in another attempt for the map
+    // stage being run concurrently.
+    // NOTE: the actual ResubmitFailedStages may get called at any time during this, but it
+    // shouldn't affect anything -- our calling it just makes *SURE* it gets called between the
+    // desired event and our check.
+    runEvent(ResubmitFailedStages)
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assert(countSubmittedMapStageAttempts() === 2)
+
+  }
+
+  /**
+   * This tests the case where a late FetchFailed comes in after the map stage has finished getting
+   * retried and a new reduce stage starts running.
+   */
+  test("extremely late fetch failures don't cause multiple concurrent attempts for " +
+      "the same stage") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+    submit(reduceRdd, Array(0, 1))
+
+    def countSubmittedReduceStageAttempts(): Int = {
+      sparkListener.submittedStageInfos.count(_.stageId == 1)
+    }
+    def countSubmittedMapStageAttempts(): Int = {
+      sparkListener.submittedStageInfos.count(_.stageId == 0)
+    }
+
+    // The map stage should have been submitted.
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assert(countSubmittedMapStageAttempts() === 1)
+
+    // Complete the map stage.
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", 2)),
+      (Success, makeMapStatus("hostB", 2))))
+
+    // The reduce stage should have been submitted.
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assert(countSubmittedReduceStageAttempts() === 1)
+
+    // The first result task fails, with a fetch failure for the output from the first mapper.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+
+    // Trigger resubmission of the failed map stage and finish the re-started map task.
+    runEvent(ResubmitFailedStages)
+    complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
+
+    // Because the map stage finished, another attempt for the reduce stage should have been
+    // submitted, resulting in 2 total attempts for each of the map and the reduce stages.
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assert(countSubmittedMapStageAttempts() === 2)
+    assert(countSubmittedReduceStageAttempts() === 2)
+
+    // A late FetchFailed arrives from the second task in the original reduce stage.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(1),
+      FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+
+    // Running ResubmitFailedStages shouldn't result in any more attempts for the map stage, because
+    // the FetchFailed should have been ignored
+    runEvent(ResubmitFailedStages)
+
+    // The FetchFailed from the original reduce stage should be ignored.
+    assert(countSubmittedMapStageAttempts() === 2)
+  }
+
   test("ignore late map task completions") {
     val shuffleMapRdd = new MyRDD(sc, 2, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)

http://git-wip-us.apache.org/repos/asf/spark/blob/80e2568b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
index 0a7cb69..b3ca150 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler
 
 import org.apache.spark.TaskContext
 
-class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) {
+class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, 0) {
   override def runTask(context: TaskContext): Int = 0
 
   override def preferredLocations: Seq[TaskLocation] = prefLocs
@@ -31,12 +31,16 @@ object FakeTask {
    * locations for each task (given as varargs) if this sequence is not empty.
    */
   def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
+    createTaskSet(numTasks, 0, prefLocs: _*)
+  }
+
+  def createTaskSet(numTasks: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
     if (prefLocs.size != 0 && prefLocs.size != numTasks) {
       throw new IllegalArgumentException("Wrong number of task locations")
     }
     val tasks = Array.tabulate[Task[_]](numTasks) { i =>
       new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil)
     }
-    new TaskSet(tasks, 0, 0, 0, null)
+    new TaskSet(tasks, 0, stageAttemptId, 0, null)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/80e2568b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
index 9b92f8d..383855c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
@@ -25,7 +25,7 @@ import org.apache.spark.TaskContext
  * A Task implementation that fails to serialize.
  */
 private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int)
-  extends Task[Array[Byte]](stageId, 0) {
+  extends Task[Array[Byte]](stageId, 0, 0) {
 
   override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
   override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()

http://git-wip-us.apache.org/repos/asf/spark/blob/80e2568b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 7c1adc1..b9b0ecc 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -41,8 +41,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     }
     val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
     val func = (c: TaskContext, i: Iterator[String]) => i.next()
-    val task = new ResultTask[String, String](
-      0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
+    val task = new ResultTask[String, String](0, 0,
+      sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
     intercept[RuntimeException] {
       task.run(0, 0)
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/80e2568b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index a6d5232..c2edd4c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -33,7 +33,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
     val taskScheduler = new TaskSchedulerImpl(sc)
     taskScheduler.initialize(new FakeSchedulerBackend)
     // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
-    val dagScheduler = new DAGScheduler(sc, taskScheduler) {
+    new DAGScheduler(sc, taskScheduler) {
       override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
       override def executorAdded(execId: String, host: String) {}
     }
@@ -67,7 +67,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
     val taskScheduler = new TaskSchedulerImpl(sc)
     taskScheduler.initialize(new FakeSchedulerBackend)
     // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
-    val dagScheduler = new DAGScheduler(sc, taskScheduler) {
+    new DAGScheduler(sc, taskScheduler) {
       override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
       override def executorAdded(execId: String, host: String) {}
     }
@@ -128,4 +128,113 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
     assert(taskDescriptions.map(_.executorId) === Seq("executor0"))
   }
 
+  test("refuse to schedule concurrent attempts for the same stage (SPARK-8103)") {
+    sc = new SparkContext("local", "TaskSchedulerImplSuite")
+    val taskScheduler = new TaskSchedulerImpl(sc)
+    taskScheduler.initialize(new FakeSchedulerBackend)
+    // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
+    val dagScheduler = new DAGScheduler(sc, taskScheduler) {
+      override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
+      override def executorAdded(execId: String, host: String) {}
+    }
+    taskScheduler.setDAGScheduler(dagScheduler)
+    val attempt1 = FakeTask.createTaskSet(1, 0)
+    val attempt2 = FakeTask.createTaskSet(1, 1)
+    taskScheduler.submitTasks(attempt1)
+    intercept[IllegalStateException] { taskScheduler.submitTasks(attempt2) }
+
+    // It is OK to submit a new attempt once all previous attempts for the stage are zombies.
+    taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId)
+      .get.isZombie = true
+    taskScheduler.submitTasks(attempt2)
+    val attempt3 = FakeTask.createTaskSet(1, 2)
+    intercept[IllegalStateException] { taskScheduler.submitTasks(attempt3) }
+    taskScheduler.taskSetManagerForAttempt(attempt2.stageId, attempt2.stageAttemptId)
+      .get.isZombie = true
+    taskScheduler.submitTasks(attempt3)
+  }
+
+  test("don't schedule more tasks after a taskset is zombie") {
+    sc = new SparkContext("local", "TaskSchedulerImplSuite")
+    val taskScheduler = new TaskSchedulerImpl(sc)
+    taskScheduler.initialize(new FakeSchedulerBackend)
+    // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
+    new DAGScheduler(sc, taskScheduler) {
+      override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
+      override def executorAdded(execId: String, host: String) {}
+    }
+
+    val numFreeCores = 1
+    val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores))
+    val attempt1 = FakeTask.createTaskSet(10)
+
+    // submit attempt 1, offer some resources, some tasks get scheduled
+    taskScheduler.submitTasks(attempt1)
+    val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(1 === taskDescriptions.length)
+
+    // now mark attempt 1 as a zombie
+    taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId)
+      .get.isZombie = true
+
+    // don't schedule anything on another resource offer
+    val taskDescriptions2 = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(0 === taskDescriptions2.length)
+
+    // if we submit another attempt for the same stage, its tasks should get scheduled
+    val attempt2 = FakeTask.createTaskSet(10, 1)
+
+    // submit attempt 2, offer some resources, some tasks get scheduled
+    taskScheduler.submitTasks(attempt2)
+    val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(1 === taskDescriptions3.length)
+    val mgr = taskScheduler.taskIdToTaskSetManager.get(taskDescriptions3(0).taskId).get
+    assert(mgr.taskSet.stageAttemptId === 1)
+  }
+
+  test("if a zombie attempt finishes, continue scheduling tasks for non-zombie attempts") {
+    sc = new SparkContext("local", "TaskSchedulerImplSuite")
+    val taskScheduler = new TaskSchedulerImpl(sc)
+    taskScheduler.initialize(new FakeSchedulerBackend)
+    // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
+    new DAGScheduler(sc, taskScheduler) {
+      override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
+      override def executorAdded(execId: String, host: String) {}
+    }
+
+    val numFreeCores = 10
+    val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores))
+    val attempt1 = FakeTask.createTaskSet(10)
+
+    // submit attempt 1, offer some resources, some tasks get scheduled
+    taskScheduler.submitTasks(attempt1)
+    val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(10 === taskDescriptions.length)
+
+    // now mark attempt 1 as a zombie
+    val mgr1 = taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId).get
+    mgr1.isZombie = true
+
+    // don't schedule anything on another resource offer
+    val taskDescriptions2 = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(0 === taskDescriptions2.length)
+
+    // submit attempt 2
+    val attempt2 = FakeTask.createTaskSet(10, 1)
+    taskScheduler.submitTasks(attempt2)
+
+    // attempt 1 finished (this can happen even if it was marked zombie earlier -- all tasks were
+    // already submitted, and then they finish)
+    taskScheduler.taskSetFinished(mgr1)
+
+    // now with another resource offer, we should still schedule all the tasks in attempt2
+    val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(10 === taskDescriptions3.length)
+
+    taskDescriptions3.foreach { task =>
+      val mgr = taskScheduler.taskIdToTaskSetManager.get(task.taskId).get
+      assert(mgr.taskSet.stageAttemptId === 1)
+    }
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/80e2568b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index cdae0d8..3abb99c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -136,7 +136,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
 /**
  * A Task implementation that results in a large serialized task.
  */
-class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0) {
+class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0) {
   val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024)
   val random = new Random(0)
   random.nextBytes(randomBuffer)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org