Posted to commits@spark.apache.org by pw...@apache.org on 2014/02/07 01:11:01 UTC

git commit: Merge pull request #321 from kayousterhout/ui_kill_fix. Closes #321.

Updated Branches:
  refs/heads/master 446403b63 -> 18ad59e2c


Merge pull request #321 from kayousterhout/ui_kill_fix. Closes #321.

Inform DAG scheduler about all started/finished tasks.

Previously, the DAG scheduler was not always informed
when tasks started and finished. The simplest example is
speculated tasks: the DAGScheduler was only told about the
first attempt of a task, so SparkListeners were never told
about the additional attempts, and users couldn't see what
speculation was doing in the UI. The DAGScheduler also
wasn't always told about finished tasks, so in the UI some
tasks were never shown as finished (this happened, for
example, when a task set was killed).
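
To illustrate what listeners see, here is a minimal sketch of a SparkListener
that records task starts and ends (it mirrors the SaveTaskEvents helper in the
SparkListenerSuite change below; the class name here is hypothetical). With the
old behavior, startedTasks and endedTasks could permanently disagree for
speculated or killed tasks:

    import scala.collection.mutable.HashSet
    import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart}

    // Records the index of every task attempt that starts and ends, so a test
    // (or the UI) can check that every started attempt eventually ends.
    class TaskEventCountingListener extends SparkListener {
      val startedTasks = new HashSet[Int]()
      val endedTasks = new HashSet[Int]()

      override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized {
        startedTasks += taskStart.taskInfo.index
      }

      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
        endedTasks += taskEnd.taskInfo.index
      }
    }

Registering it with sc.addSparkListener(listener) is enough to receive these
events, as the new test below does.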

The other problem was that the fairness accounting was
wrong -- the number of running tasks in a pool was decreased
when a task set was considered done, even if all of its
tasks hadn't yet finished.
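
The accounting fix is summarized by this simplified sketch (condensed from the
TaskSetManager changes below; Pool here is a stand-in for the real fair
scheduler pool): the running-task count is always derived from the set of task
IDs still running, and the parent pool is decremented one task at a time as
attempts actually finish, never zeroed early just because the task set is
considered done:

    import scala.collection.mutable.HashSet

    // Stand-in for the fair scheduler pool's running-task counter.
    class Pool {
      var runningTasks = 0
      def increaseRunningTasks(n: Int) { runningTasks += n }
      def decreaseRunningTasks(n: Int) { runningTasks -= n }
    }

    // Per-task-set bookkeeping: runningTasks reflects runningTasksSet.size,
    // so a "done" task set with straggling or speculated attempts still
    // counts toward its pool's share until those attempts finish.
    class TaskSetAccounting(parent: Pool) {
      val runningTasksSet = new HashSet[Long]
      def runningTasks = runningTasksSet.size

      def addRunningTask(tid: Long) {
        if (runningTasksSet.add(tid)) parent.increaseRunningTasks(1)
      }

      def removeRunningTask(tid: Long) {
        if (runningTasksSet.remove(tid)) parent.decreaseRunningTasks(1)
      }
    }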

Author: Kay Ousterhout <ka...@gmail.com>

== Merge branch commits ==

commit c8d547d0f7a17f5a193bef05f5872b9f475675c5
Author: Kay Ousterhout <ka...@gmail.com>
Date:   Wed Jan 15 16:47:33 2014 -0800

    Addressed Reynold's review comments.

    Always use a TaskEndReason (remove the option), and explicitly
    signal when we don't know the reason. Also, always tell
    DAGScheduler (and associated listeners) about started tasks, even
    when they're speculated.
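
    A script-style sketch of that API change, using hypothetical stand-ins for
    the private[spark] TaskEndReason hierarchy added in this patch:

        sealed trait EndReason
        case object ResultLost extends EndReason
        case object Killed extends EndReason
        case object Unknown extends EndReason  // explicit "reason not known" value

        // Before: handleFailedTask took Option[EndReason], and None meant the
        // reason was unknown. After: it always takes a concrete EndReason,
        // defaulting to Unknown.
        def handleFailedTask(tid: Long, reason: EndReason): Unit =
          println("TID %s failed: %s".format(tid, reason))

        handleFailedTask(7L, ResultLost)  // was handleFailedTask(7L, Some(TaskResultLost))
        handleFailedTask(8L, Unknown)     // was handleFailedTask(8L, None)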

commit 3fee1e2e3c06b975ff7f95d595448f38cce97a04
Author: Kay Ousterhout <ka...@gmail.com>
Date:   Wed Jan 8 22:58:13 2014 -0800

    Fixed broken test and improved logging

commit ff12fcaa2567c5d02b75a1d5db35687225bcd46f
Author: Kay Ousterhout <ka...@gmail.com>
Date:   Sun Dec 29 21:08:20 2013 -0800

    Inform DAG scheduler about all finished tasks.

    Previously, the DAG scheduler was not always informed
    when tasks finished. For example, when a task set was
    aborted, the DAG scheduler was never told when the tasks
    in that task set finished. The DAG scheduler was also
    never told about the completion of speculated tasks.
    This confused SparkListeners, because information about
    the completion of those tasks was never passed on to the
    listeners (so in the UI, for example, some tasks would
    never be shown as finished).

    The other problem was that the fairness accounting was
    wrong -- the number of running tasks in a pool was decreased
    when a task set was considered done, even if all of its
    tasks hadn't yet finished.


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/18ad59e2
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/18ad59e2
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/18ad59e2

Branch: refs/heads/master
Commit: 18ad59e2c6b7bd009e8ba5ebf8fcf99630863029
Parents: 446403b
Author: Kay Ousterhout <ka...@gmail.com>
Authored: Thu Feb 6 16:10:48 2014 -0800
Committer: Patrick Wendell <pw...@gmail.com>
Committed: Thu Feb 6 16:10:48 2014 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/TaskEndReason.scala  |  13 ++
 .../apache/spark/scheduler/DAGScheduler.scala   |   4 +-
 .../org/apache/spark/scheduler/StageInfo.scala  |   6 +
 .../spark/scheduler/TaskResultGetter.scala      |   8 +-
 .../spark/scheduler/TaskSchedulerImpl.scala     |  46 ++---
 .../apache/spark/scheduler/TaskSetManager.scala | 193 +++++++++----------
 .../spark/scheduler/ClusterSchedulerSuite.scala |  12 +-
 .../spark/scheduler/SparkListenerSuite.scala    |  41 +++-
 .../spark/scheduler/TaskSetManagerSuite.scala   |   4 +-
 9 files changed, 183 insertions(+), 144 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/18ad59e2/core/src/main/scala/org/apache/spark/TaskEndReason.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index faf6dcd..3fd6f5e 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -53,3 +53,16 @@ private[spark] case class ExceptionFailure(
 private[spark] case object TaskResultLost extends TaskEndReason
 
 private[spark] case object TaskKilled extends TaskEndReason
+
+/**
+ * The task failed because the executor that it was running on was lost. This may happen because
+ * the task crashed the JVM.
+ */
+private[spark] case object ExecutorLostFailure extends TaskEndReason
+
+/**
+ * We don't know why the task ended -- for example, because of a ClassNotFound exception when
+ * deserializing the task result.
+ */
+private[spark] case object UnknownReason extends TaskEndReason
+

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/18ad59e2/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 237cbf4..8212415 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -954,8 +954,8 @@ class DAGScheduler(
         // Do nothing here; the TaskScheduler handles these failures and resubmits the task.
 
       case other =>
-        // Unrecognized failure - abort all jobs depending on this stage
-        abortStage(stageIdToStage(task.stageId), task + " failed: " + other)
+        // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler
+        // will abort the job.
     }
   }
 

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/18ad59e2/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index e9f2198..c4d1ad5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -21,6 +21,12 @@ import scala.collection._
 
 import org.apache.spark.executor.TaskMetrics
 
+/**
+ * Stores information about a stage to pass from the scheduler to SparkListeners.
+ *
+ * taskInfos stores the metrics for all tasks that have completed, including redundant, speculated
+ * tasks.
+ */
 class StageInfo(
     stage: Stage,
     val taskInfos: mutable.Buffer[(TaskInfo, TaskMetrics)] = mutable.Buffer[(TaskInfo, TaskMetrics)]()

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/18ad59e2/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 35e9544..bdec08e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -57,7 +57,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
                  * between when the task ended and when we tried to fetch the result, or if the
                  * block manager had to flush the result. */
                 scheduler.handleFailedTask(
-                  taskSetManager, tid, TaskState.FINISHED, Some(TaskResultLost))
+                  taskSetManager, tid, TaskState.FINISHED, TaskResultLost)
                 return
               }
               val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]](
@@ -80,13 +80,13 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
 
   def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState,
     serializedData: ByteBuffer) {
-    var reason: Option[TaskEndReason] = None
+    var reason : TaskEndReason = UnknownReason
     getTaskResultExecutor.execute(new Runnable {
       override def run() {
         try {
           if (serializedData != null && serializedData.limit() > 0) {
-            reason = Some(serializer.get().deserialize[TaskEndReason](
-              serializedData, getClass.getClassLoader))
+            reason = serializer.get().deserialize[TaskEndReason](
+              serializedData, getClass.getClassLoader)
           }
         } catch {
           case cnd: ClassNotFoundException =>

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/18ad59e2/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 83ba584..5b52515 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -67,7 +67,6 @@ private[spark] class TaskSchedulerImpl(
 
   val taskIdToTaskSetId = new HashMap[Long, String]
   val taskIdToExecutorId = new HashMap[Long, String]
-  val taskSetTaskIds = new HashMap[String, HashSet[Long]]
 
   @volatile private var hasReceivedTask = false
   @volatile private var hasLaunchedTask = false
@@ -142,7 +141,6 @@ private[spark] class TaskSchedulerImpl(
       val manager = new TaskSetManager(this, taskSet, maxTaskFailures)
       activeTaskSets(taskSet.id) = manager
       schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
-      taskSetTaskIds(taskSet.id) = new HashSet[Long]()
 
       if (!isLocal && !hasReceivedTask) {
         starvationTimer.scheduleAtFixedRate(new TimerTask() {
@@ -171,31 +169,25 @@ private[spark] class TaskSchedulerImpl(
       //    the stage.
       // 2. The task set manager has been created but no tasks has been scheduled. In this case,
       //    simply abort the stage.
-      val taskIds = taskSetTaskIds(tsm.taskSet.id)
-      if (taskIds.size > 0) {
-        taskIds.foreach { tid =>
-          val execId = taskIdToExecutorId(tid)
-          backend.killTask(tid, execId)
-        }
+      tsm.runningTasksSet.foreach { tid =>
+        val execId = taskIdToExecutorId(tid)
+        backend.killTask(tid, execId)
       }
+      tsm.abort("Stage %s cancelled".format(stageId))
       logInfo("Stage %d was cancelled".format(stageId))
-      tsm.removeAllRunningTasks()
-      taskSetFinished(tsm)
     }
   }
 
+  /**
+   * Called to indicate that all task attempts (including speculated tasks) associated with the
+   * given TaskSetManager have completed, so state associated with the TaskSetManager should be
+   * cleaned up.
+   */
   def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
-    // Check to see if the given task set has been removed. This is possible in the case of
-    // multiple unrecoverable task failures (e.g. if the entire task set is killed when it has
-    // more than one running tasks).
-    if (activeTaskSets.contains(manager.taskSet.id)) {
-      activeTaskSets -= manager.taskSet.id
-      manager.parent.removeSchedulable(manager)
-      logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
-      taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
-      taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id)
-      taskSetTaskIds.remove(manager.taskSet.id)
-    }
+    activeTaskSets -= manager.taskSet.id
+    manager.parent.removeSchedulable(manager)
+    logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s"
+      .format(manager.taskSet.id, manager.parent.name))
   }
 
   /**
@@ -237,7 +229,6 @@ private[spark] class TaskSchedulerImpl(
             tasks(i) += task
             val tid = task.taskId
             taskIdToTaskSetId(tid) = taskSet.taskSet.id
-            taskSetTaskIds(taskSet.taskSet.id) += tid
             taskIdToExecutorId(tid) = execId
             activeExecutorIds += execId
             executorsByHost(host) += execId
@@ -270,9 +261,6 @@ private[spark] class TaskSchedulerImpl(
           case Some(taskSetId) =>
             if (TaskState.isFinished(state)) {
               taskIdToTaskSetId.remove(tid)
-              if (taskSetTaskIds.contains(taskSetId)) {
-                taskSetTaskIds(taskSetId) -= tid
-              }
               taskIdToExecutorId.remove(tid)
             }
             activeTaskSets.get(taskSetId).foreach { taskSet =>
@@ -285,7 +273,9 @@ private[spark] class TaskSchedulerImpl(
               }
             }
           case None =>
-            logInfo("Ignoring update with state %s from TID %s because its task set is gone"
+            logError(
+              ("Ignoring update with state %s for TID %s because its task set is gone (this is " +
+               "likely the result of receiving duplicate task finished status updates)")
               .format(state, tid))
         }
       } catch {
@@ -314,9 +304,9 @@ private[spark] class TaskSchedulerImpl(
     taskSetManager: TaskSetManager,
     tid: Long,
     taskState: TaskState,
-    reason: Option[TaskEndReason]) = synchronized {
+    reason: TaskEndReason) = synchronized {
     taskSetManager.handleFailedTask(tid, taskState, reason)
-    if (taskState != TaskState.KILLED) {
+    if (!taskSetManager.isZombie && taskState != TaskState.KILLED) {
       // Need to revive offers again now that the task set manager state has been updated to
       // reflect failed tasks that need to be re-run.
       backend.reviveOffers()

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/18ad59e2/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 777f31d..3f0ee7a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -26,9 +26,10 @@ import scala.collection.mutable.HashSet
 import scala.math.max
 import scala.math.min
 
-import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv,
-  Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState}
+import org.apache.spark.{ExceptionFailure, ExecutorLostFailure, FetchFailed, Logging, Resubmitted,
+  SparkEnv, Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState}
 import org.apache.spark.TaskState.TaskState
+import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.util.{Clock, SystemClock}
 
 
@@ -82,8 +83,16 @@ private[spark] class TaskSetManager(
   var name = "TaskSet_"+taskSet.stageId.toString
   var parent: Pool = null
 
-  var runningTasks = 0
-  private val runningTasksSet = new HashSet[Long]
+  val runningTasksSet = new HashSet[Long]
+  override def runningTasks = runningTasksSet.size
+
+  // True once no more tasks should be launched for this task set manager. TaskSetManagers enter
+  // the zombie state once at least one attempt of each task has completed successfully, or if the
+  // task set is aborted (for example, because it was killed).  TaskSetManagers remain in the zombie
+  // state until all tasks have finished running; we keep TaskSetManagers that are in the zombie
+  // state in order to continue to track and account for the running tasks.
+  // TODO: We should kill any running task attempts when the task set manager becomes a zombie.
+  var isZombie = false
 
   // Set of pending tasks for each executor. These collections are actually
   // treated as stacks, in which new tasks are added to the end of the
@@ -345,7 +354,7 @@ private[spark] class TaskSetManager(
       maxLocality: TaskLocality.TaskLocality)
     : Option[TaskDescription] =
   {
-    if (tasksSuccessful < numTasks && availableCpus >= CPUS_PER_TASK) {
+    if (!isZombie && availableCpus >= CPUS_PER_TASK) {
       val curTime = clock.getTime()
 
       var allowedLocality = getAllowedLocalityLevel(curTime)
@@ -380,8 +389,7 @@ private[spark] class TaskSetManager(
           logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
             taskSet.id, index, serializedTask.limit, timeTaken))
           val taskName = "task %s:%d".format(taskSet.id, index)
-          if (taskAttempts(index).size == 1)
-            taskStarted(task,info)
+          sched.dagScheduler.taskStarted(task, info)
           return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
         }
         case _ =>
@@ -390,6 +398,12 @@ private[spark] class TaskSetManager(
     None
   }
 
+  private def maybeFinishTaskSet() {
+    if (isZombie && runningTasks == 0) {
+      sched.taskSetFinished(this)
+    }
+  }
+
   /**
    * Get the level we can launch tasks according to delay scheduling, based on current wait time.
    */
@@ -418,10 +432,6 @@ private[spark] class TaskSetManager(
     index
   }
 
-  private def taskStarted(task: Task[_], info: TaskInfo) {
-    sched.dagScheduler.taskStarted(task, info)
-  }
-
   def handleTaskGettingResult(tid: Long) = {
     val info = taskInfos(tid)
     info.markGettingResult()
@@ -436,123 +446,116 @@ private[spark] class TaskSetManager(
     val index = info.index
     info.markSuccessful()
     removeRunningTask(tid)
+    sched.dagScheduler.taskEnded(
+      tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
     if (!successful(index)) {
       tasksSuccessful += 1
       logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
         tid, info.duration, info.host, tasksSuccessful, numTasks))
-      sched.dagScheduler.taskEnded(
-        tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
-
       // Mark successful and stop if all the tasks have succeeded.
       successful(index) = true
       if (tasksSuccessful == numTasks) {
-        sched.taskSetFinished(this)
+        isZombie = true
       }
     } else {
       logInfo("Ignorning task-finished event for TID " + tid + " because task " +
         index + " has already completed successfully")
     }
+    maybeFinishTaskSet()
   }
 
   /**
    * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
    * DAG Scheduler.
    */
-  def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) {
+  def handleFailedTask(tid: Long, state: TaskState, reason: TaskEndReason) {
     val info = taskInfos(tid)
     if (info.failed) {
       return
     }
     removeRunningTask(tid)
-    val index = info.index
     info.markFailed()
-    var failureReason = "unknown"
-    if (!successful(index)) {
+    val index = info.index
+    copiesRunning(index) -= 1
+    if (!isZombie) {
       logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
-      copiesRunning(index) -= 1
-      // Check if the problem is a map output fetch failure. In that case, this
-      // task will never succeed on any node, so tell the scheduler about it.
-      reason.foreach {
-        case fetchFailed: FetchFailed =>
-          logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress)
-          sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null)
+    }
+    var taskMetrics : TaskMetrics = null
+    var failureReason = "unknown"
+    reason match {
+      case fetchFailed: FetchFailed =>
+        logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress)
+        if (!successful(index)) {
           successful(index) = true
           tasksSuccessful += 1
-          sched.taskSetFinished(this)
-          removeAllRunningTasks()
-          return
-
-        case TaskKilled =>
-          logWarning("Task %d was killed.".format(tid))
-          sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null)
+        }
+        isZombie = true
+
+      case TaskKilled =>
+        logWarning("Task %d was killed.".format(tid))
+
+      case ef: ExceptionFailure =>
+        taskMetrics = ef.metrics.getOrElse(null)
+        if (ef.className == classOf[NotSerializableException].getName()) {
+          // If the task result wasn't serializable, there's no point in trying to re-execute it.
+          logError("Task %s:%s had a not serializable result: %s; not retrying".format(
+            taskSet.id, index, ef.description))
+          abort("Task %s:%s had a not serializable result: %s".format(
+            taskSet.id, index, ef.description))
           return
-
-        case ef: ExceptionFailure =>
-          sched.dagScheduler.taskEnded(
-            tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
-          if (ef.className == classOf[NotSerializableException].getName()) {
-            // If the task result wasn't rerializable, there's no point in trying to re-execute it.
-            logError("Task %s:%s had a not serializable result: %s; not retrying".format(
-              taskSet.id, index, ef.description))
-            abort("Task %s:%s had a not serializable result: %s".format(
-              taskSet.id, index, ef.description))
-            return
-          }
-          val key = ef.description
-          failureReason = "Exception failure: %s".format(ef.description)
-          val now = clock.getTime()
-          val (printFull, dupCount) = {
-            if (recentExceptions.contains(key)) {
-              val (dupCount, printTime) = recentExceptions(key)
-              if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
-                recentExceptions(key) = (0, now)
-                (true, 0)
-              } else {
-                recentExceptions(key) = (dupCount + 1, printTime)
-                (false, dupCount + 1)
-              }
-            } else {
+        }
+        val key = ef.description
+        failureReason = "Exception failure: %s".format(ef.description)
+        val now = clock.getTime()
+        val (printFull, dupCount) = {
+          if (recentExceptions.contains(key)) {
+            val (dupCount, printTime) = recentExceptions(key)
+            if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
               recentExceptions(key) = (0, now)
               (true, 0)
+            } else {
+              recentExceptions(key) = (dupCount + 1, printTime)
+              (false, dupCount + 1)
             }
-          }
-          if (printFull) {
-            val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
-            logWarning("Loss was due to %s\n%s\n%s".format(
-              ef.className, ef.description, locs.mkString("\n")))
           } else {
-            logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
+            recentExceptions(key) = (0, now)
+            (true, 0)
           }
+        }
+        if (printFull) {
+          val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
+          logWarning("Loss was due to %s\n%s\n%s".format(
+            ef.className, ef.description, locs.mkString("\n")))
+        } else {
+          logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
+        }
 
-        case TaskResultLost =>
-          failureReason = "Lost result for TID %s on host %s".format(tid, info.host)
-          logWarning(failureReason)
-          sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
+      case TaskResultLost =>
+        failureReason = "Lost result for TID %s on host %s".format(tid, info.host)
+        logWarning(failureReason)
 
-        case _ => {}
-      }
-      // On non-fetch failures, re-enqueue the task as pending for a max number of retries
-      addPendingTask(index)
-      if (state != TaskState.KILLED) {
-        numFailures(index) += 1
-        if (numFailures(index) >= maxTaskFailures) {
-          logError("Task %s:%d failed %d times; aborting job".format(
-            taskSet.id, index, maxTaskFailures))
-          abort("Task %s:%d failed %d times (most recent failure: %s)".format(
-            taskSet.id, index, maxTaskFailures, failureReason))
-        }
+      case _ => {}
+    }
+    sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics)
+    addPendingTask(index)
+    if (!isZombie && state != TaskState.KILLED) {
+      numFailures(index) += 1
+      if (numFailures(index) >= maxTaskFailures) {
+        logError("Task %s:%d failed %d times; aborting job".format(
+          taskSet.id, index, maxTaskFailures))
+        abort("Task %s:%d failed %d times (most recent failure: %s)".format(
+          taskSet.id, index, maxTaskFailures, failureReason))
+        return
       }
-    } else {
-      logInfo("Ignoring task-lost event for TID " + tid +
-        " because task " + index + " is already finished")
     }
+    maybeFinishTaskSet()
   }
 
   def abort(message: String) {
     // TODO: Kill running tasks if we were not terminated due to a Mesos error
     sched.dagScheduler.taskSetFailed(taskSet, message)
-    removeAllRunningTasks()
-    sched.taskSetFinished(this)
+    isZombie = true
+    maybeFinishTaskSet()
   }
 
   /** If the given task ID is not in the set of running tasks, adds it.
@@ -563,7 +566,6 @@ private[spark] class TaskSetManager(
     if (runningTasksSet.add(tid) && parent != null) {
       parent.increaseRunningTasks(1)
     }
-    runningTasks = runningTasksSet.size
   }
 
   /** If the given task ID is in the set of running tasks, removes it. */
@@ -571,16 +573,6 @@ private[spark] class TaskSetManager(
     if (runningTasksSet.remove(tid) && parent != null) {
       parent.decreaseRunningTasks(1)
     }
-    runningTasks = runningTasksSet.size
-  }
-
-  private[scheduler] def removeAllRunningTasks() {
-    val numRunningTasks = runningTasksSet.size
-    runningTasksSet.clear()
-    if (parent != null) {
-      parent.decreaseRunningTasks(numRunningTasks)
-    }
-    runningTasks = 0
   }
 
   override def getSchedulableByName(name: String): Schedulable = {
@@ -629,7 +621,7 @@ private[spark] class TaskSetManager(
     }
     // Also re-enqueue any tasks that were running on the node
     for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
-      handleFailedTask(tid, TaskState.FAILED, None)
+      handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure)
     }
   }
 
@@ -641,8 +633,9 @@ private[spark] class TaskSetManager(
    * we don't scan the whole task set. It might also help to make this sorted by launch time.
    */
   override def checkSpeculatableTasks(): Boolean = {
-    // Can't speculate if we only have one task, or if all tasks have finished.
-    if (numTasks == 1 || tasksSuccessful == numTasks) {
+    // Can't speculate if we only have one task, and no need to speculate if the task set is a
+    // zombie.
+    if (isZombie || numTasks == 1) {
       return false
     }
     var foundTasks = false

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/18ad59e2/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
index 235d317..98ea4cb 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
@@ -36,22 +36,24 @@ class FakeTaskSetManager(
   parent = null
   weight = 1
   minShare = 2
-  runningTasks = 0
   priority = initPriority
   stageId = initStageId
   name = "TaskSet_"+stageId
   override val numTasks = initNumTasks
   tasksSuccessful = 0
 
+  var numRunningTasks = 0
+  override def runningTasks = numRunningTasks
+
   def increaseRunningTasks(taskNum: Int) {
-    runningTasks += taskNum
+    numRunningTasks += taskNum
     if (parent != null) {
       parent.increaseRunningTasks(taskNum)
     }
   }
 
   def decreaseRunningTasks(taskNum: Int) {
-    runningTasks -= taskNum
+    numRunningTasks -= taskNum
     if (parent != null) {
       parent.decreaseRunningTasks(taskNum)
     }
@@ -77,7 +79,7 @@ class FakeTaskSetManager(
       maxLocality: TaskLocality.TaskLocality)
     : Option[TaskDescription] =
   {
-    if (tasksSuccessful + runningTasks < numTasks) {
+    if (tasksSuccessful + numRunningTasks < numTasks) {
       increaseRunningTasks(1)
       Some(new TaskDescription(0, execId, "task 0:0", 0, null))
     } else {
@@ -98,7 +100,7 @@ class FakeTaskSetManager(
   }
 
   def abort() {
-    decreaseRunningTasks(runningTasks)
+    decreaseRunningTasks(numRunningTasks)
     parent.removeSchedulable(this)
   }
 }

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/18ad59e2/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 1a16e43..368c515 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -168,6 +168,39 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
     assert(listener.endedTasks.contains(TASK_INDEX))
   }
 
+  test("onTaskEnd() should be called for all started tasks, even after job has been killed") {
+    val WAIT_TIMEOUT_MILLIS = 10000
+    val listener = new SaveTaskEvents
+    sc.addSparkListener(listener)
+
+    val numTasks = 10
+    val f = sc.parallelize(1 to 10000, numTasks).map { i => Thread.sleep(10); i }.countAsync()
+    // Wait until one task has started (because we want to make sure that any tasks that are started
+    // have corresponding end events sent to the listener).
+    var finishTime = System.currentTimeMillis + WAIT_TIMEOUT_MILLIS
+    listener.synchronized {
+      var remainingWait = finishTime - System.currentTimeMillis
+      while (listener.startedTasks.isEmpty && remainingWait > 0) {
+        listener.wait(remainingWait)
+        remainingWait = finishTime - System.currentTimeMillis
+      }
+      assert(!listener.startedTasks.isEmpty)
+    }
+
+    f.cancel()
+
+    // Ensure that onTaskEnd is called for all started tasks.
+    finishTime = System.currentTimeMillis + WAIT_TIMEOUT_MILLIS
+    listener.synchronized {
+      var remainingWait = finishTime - System.currentTimeMillis
+      while (listener.endedTasks.size < listener.startedTasks.size && remainingWait > 0) {
+        listener.wait(finishTime - System.currentTimeMillis)
+        remainingWait = finishTime - System.currentTimeMillis
+      }
+      assert(listener.endedTasks.size === listener.startedTasks.size)
+    }
+  }
+
   def checkNonZeroAvg(m: Traversable[Long], msg: String) {
     assert(m.sum / m.size.toDouble > 0.0, msg)
   }
@@ -184,12 +217,14 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
     val startedGettingResultTasks = new HashSet[Int]()
     val endedTasks = new HashSet[Int]()
 
-    override def onTaskStart(taskStart: SparkListenerTaskStart) {
+    override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized {
       startedTasks += taskStart.taskInfo.index
+      notify()
     }
 
-    override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
-        endedTasks += taskEnd.taskInfo.index
+    override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
+      endedTasks += taskEnd.taskInfo.index
+      notify()
     }
 
     override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) {

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/18ad59e2/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index ecac2f7..de321c4 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -269,7 +269,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
     assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
 
     // Tell it the task has finished but the result was lost.
-    manager.handleFailedTask(0, TaskState.FINISHED, Some(TaskResultLost))
+    manager.handleFailedTask(0, TaskState.FINISHED, TaskResultLost)
     assert(sched.endedTasks(0) === TaskResultLost)
 
     // Re-offer the host -- now we should get task 0 again.
@@ -290,7 +290,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
       assert(offerResult.isDefined,
         "Expect resource offer on iteration %s to return a task".format(index))
       assert(offerResult.get.index === 0)
-      manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, Some(TaskResultLost))
+      manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost)
       if (index < MAX_TASK_FAILURES) {
         assert(!sched.taskSetsFailed.contains(taskSet.id))
       } else {