You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain-text export.
Posted to commits@spark.apache.org by rx...@apache.org on 2014/07/26 03:45:07 UTC

git commit: Part of [SPARK-2456] Removed some HashMaps from DAGScheduler by storing information in Stage.

Repository: spark
Updated Branches:
  refs/heads/master afd757a24 -> 9d8666cac


Part of [SPARK-2456] Removed some HashMaps from DAGScheduler by storing information in Stage.

This is part of the scheduler cleanup/refactoring effort to make the scheduler code easier to maintain.

@kayousterhout @markhamstra please take a look ...

Author: Reynold Xin <rx...@apache.org>

Closes #1561 from rxin/dagSchedulerHashMaps and squashes the following commits:

1c44e15 [Reynold Xin] Clear pending tasks in submitMissingTasks.
620a0d1 [Reynold Xin] Use filterKeys.
5b54404 [Reynold Xin] Code review feedback.
c1e9a1c [Reynold Xin] Removed some HashMaps from DAGScheduler by storing information in Stage.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9d8666ca
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9d8666ca
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9d8666ca

Branch: refs/heads/master
Commit: 9d8666cac84fc4fc867f6a5e80097dbe5cb65301
Parents: afd757a
Author: Reynold Xin <rx...@apache.org>
Authored: Fri Jul 25 18:45:02 2014 -0700
Committer: Reynold Xin <rx...@apache.org>
Committed: Fri Jul 25 18:45:02 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/scheduler/DAGScheduler.scala   | 143 +++++++------------
 .../org/apache/spark/scheduler/Stage.scala      |  19 ++-
 .../spark/scheduler/DAGSchedulerSuite.scala     |   4 -
 3 files changed, 69 insertions(+), 97 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9d8666ca/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 00b8af2..dc6142a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -85,12 +85,9 @@ class DAGScheduler(
   private val nextStageId = new AtomicInteger(0)
 
   private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]]
-  private[scheduler] val stageIdToJobIds = new HashMap[Int, HashSet[Int]]
   private[scheduler] val stageIdToStage = new HashMap[Int, Stage]
   private[scheduler] val shuffleToMapStage = new HashMap[Int, Stage]
   private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob]
-  private[scheduler] val resultStageToJob = new HashMap[Stage, ActiveJob]
-  private[scheduler] val stageToInfos = new HashMap[Stage, StageInfo]
 
   // Stages we need to run whose parents aren't done
   private[scheduler] val waitingStages = new HashSet[Stage]
@@ -101,9 +98,6 @@ class DAGScheduler(
   // Stages that must be resubmitted due to fetch failures
   private[scheduler] val failedStages = new HashSet[Stage]
 
-  // Missing tasks from each stage
-  private[scheduler] val pendingTasks = new HashMap[Stage, HashSet[Task[_]]]
-
   private[scheduler] val activeJobs = new HashSet[ActiveJob]
 
   // Contains the locations that each RDD's partitions are cached on
@@ -223,7 +217,6 @@ class DAGScheduler(
       new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
     stageIdToStage(id) = stage
     updateJobIdStageIdMaps(jobId, stage)
-    stageToInfos(stage) = StageInfo.fromStage(stage)
     stage
   }
 
@@ -315,13 +308,12 @@ class DAGScheduler(
    */
   private def updateJobIdStageIdMaps(jobId: Int, stage: Stage) {
     def updateJobIdStageIdMapsList(stages: List[Stage]) {
-      if (!stages.isEmpty) {
+      if (stages.nonEmpty) {
         val s = stages.head
-        stageIdToJobIds.getOrElseUpdate(s.id, new HashSet[Int]()) += jobId
+        s.jobIds += jobId
         jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id
-        val parents = getParentStages(s.rdd, jobId)
-        val parentsWithoutThisJobId = parents.filter(p =>
-          !stageIdToJobIds.get(p.id).exists(_.contains(jobId)))
+        val parents: List[Stage] = getParentStages(s.rdd, jobId)
+        val parentsWithoutThisJobId = parents.filter { ! _.jobIds.contains(jobId) }
         updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail)
       }
     }
@@ -333,16 +325,15 @@ class DAGScheduler(
    * handle cancelling tasks or notifying the SparkListener about finished jobs/stages/tasks.
    *
    * @param job The job whose state to cleanup.
-   * @param resultStage Specifies the result stage for the job; if set to None, this method
-   *                    searches resultStagesToJob to find and cleanup the appropriate result stage.
    */
-  private def cleanupStateForJobAndIndependentStages(job: ActiveJob, resultStage: Option[Stage]) {
+  private def cleanupStateForJobAndIndependentStages(job: ActiveJob) {
     val registeredStages = jobIdToStageIds.get(job.jobId)
     if (registeredStages.isEmpty || registeredStages.get.isEmpty) {
       logError("No stages registered for job " + job.jobId)
     } else {
-      stageIdToJobIds.filterKeys(stageId => registeredStages.get.contains(stageId)).foreach {
-        case (stageId, jobSet) =>
+      stageIdToStage.filterKeys(stageId => registeredStages.get.contains(stageId)).foreach {
+        case (stageId, stage) =>
+          val jobSet = stage.jobIds
           if (!jobSet.contains(job.jobId)) {
             logError(
               "Job %d not registered for stage %d even though that stage was registered for the job"
@@ -355,14 +346,9 @@ class DAGScheduler(
                   logDebug("Removing running stage %d".format(stageId))
                   runningStages -= stage
                 }
-                stageToInfos -= stage
                 for ((k, v) <- shuffleToMapStage.find(_._2 == stage)) {
                   shuffleToMapStage.remove(k)
                 }
-                if (pendingTasks.contains(stage) && !pendingTasks(stage).isEmpty) {
-                  logDebug("Removing pending status for stage %d".format(stageId))
-                }
-                pendingTasks -= stage
                 if (waitingStages.contains(stage)) {
                   logDebug("Removing stage %d from waiting set.".format(stageId))
                   waitingStages -= stage
@@ -374,7 +360,6 @@ class DAGScheduler(
               }
               // data structures based on StageId
               stageIdToStage -= stageId
-              stageIdToJobIds -= stageId
 
               ShuffleMapTask.removeStage(stageId)
               ResultTask.removeStage(stageId)
@@ -393,19 +378,7 @@ class DAGScheduler(
     jobIdToStageIds -= job.jobId
     jobIdToActiveJob -= job.jobId
     activeJobs -= job
-
-    if (resultStage.isEmpty) {
-      // Clean up result stages.
-      val resultStagesForJob = resultStageToJob.keySet.filter(
-        stage => resultStageToJob(stage).jobId == job.jobId)
-      if (resultStagesForJob.size != 1) {
-        logWarning(
-          s"${resultStagesForJob.size} result stages for job ${job.jobId} (expect exactly 1)")
-      }
-      resultStageToJob --= resultStagesForJob
-    } else {
-      resultStageToJob -= resultStage.get
-    }
+    job.finalStage.resultOfJob = None
   }
 
   /**
@@ -591,9 +564,10 @@ class DAGScheduler(
         job.listener.jobFailed(exception)
     } finally {
       val s = job.finalStage
-      stageIdToJobIds -= s.id    // clean up data structures that were populated for a local job,
-      stageIdToStage -= s.id     // but that won't get cleaned up via the normal paths through
-      stageToInfos -= s          // completion events or stage abort
+      // clean up data structures that were populated for a local job,
+      // but that won't get cleaned up via the normal paths through
+      // completion events or stage abort
+      stageIdToStage -= s.id
       jobIdToStageIds -= job.jobId
       listenerBus.post(SparkListenerJobEnd(job.jobId, jobResult))
     }
@@ -605,12 +579,8 @@ class DAGScheduler(
   // That should take care of at least part of the priority inversion problem with
   // cross-job dependencies.
   private def activeJobForStage(stage: Stage): Option[Int] = {
-    if (stageIdToJobIds.contains(stage.id)) {
-      val jobsThatUseStage: Array[Int] = stageIdToJobIds(stage.id).toArray.sorted
-      jobsThatUseStage.find(jobIdToActiveJob.contains)
-    } else {
-      None
-    }
+    val jobsThatUseStage: Array[Int] = stage.jobIds.toArray.sorted
+    jobsThatUseStage.find(jobIdToActiveJob.contains)
   }
 
   private[scheduler] def handleJobGroupCancelled(groupId: String) {
@@ -642,9 +612,8 @@ class DAGScheduler(
       // is in the process of getting stopped.
       val stageFailedMessage = "Stage cancelled because SparkContext was shut down"
       runningStages.foreach { stage =>
-        val info = stageToInfos(stage)
-        info.stageFailed(stageFailedMessage)
-        listenerBus.post(SparkListenerStageCompleted(info))
+        stage.info.stageFailed(stageFailedMessage)
+        listenerBus.post(SparkListenerStageCompleted(stage.info))
       }
       listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
     }
@@ -690,7 +659,7 @@ class DAGScheduler(
       } else {
         jobIdToActiveJob(jobId) = job
         activeJobs += job
-        resultStageToJob(finalStage) = job
+        finalStage.resultOfJob = Some(job)
         listenerBus.post(SparkListenerJobStart(job.jobId, jobIdToStageIds(jobId).toArray,
           properties))
         submitStage(finalStage)
@@ -727,8 +696,7 @@ class DAGScheduler(
   private def submitMissingTasks(stage: Stage, jobId: Int) {
     logDebug("submitMissingTasks(" + stage + ")")
     // Get our pending tasks and remember them in our pendingTasks entry
-    val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
-    myPending.clear()
+    stage.pendingTasks.clear()
     var tasks = ArrayBuffer[Task[_]]()
     if (stage.isShuffleMap) {
       for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
@@ -737,7 +705,7 @@ class DAGScheduler(
       }
     } else {
       // This is a final stage; figure out its job's missing partitions
-      val job = resultStageToJob(stage)
+      val job = stage.resultOfJob.get
       for (id <- 0 until job.numPartitions if !job.finished(id)) {
         val partition = job.partitions(id)
         val locs = getPreferredLocs(stage.rdd, partition)
@@ -758,7 +726,7 @@ class DAGScheduler(
       // serializable. If tasks are not serializable, a SparkListenerStageCompleted event
       // will be posted, which should always come after a corresponding SparkListenerStageSubmitted
       // event.
-      listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), properties))
+      listenerBus.post(SparkListenerStageSubmitted(stage.info, properties))
 
       // Preemptively serialize a task to make sure it can be serialized. We are catching this
       // exception here because it would be fairly hard to catch the non-serializable exception
@@ -778,11 +746,11 @@ class DAGScheduler(
       }
 
       logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
-      myPending ++= tasks
-      logDebug("New pending tasks: " + myPending)
+      stage.pendingTasks ++= tasks
+      logDebug("New pending tasks: " + stage.pendingTasks)
       taskScheduler.submitTasks(
         new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
-      stageToInfos(stage).submissionTime = Some(clock.getTime())
+      stage.info.submissionTime = Some(clock.getTime())
     } else {
       logDebug("Stage " + stage + " is actually done; %b %d %d".format(
         stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
@@ -807,13 +775,13 @@ class DAGScheduler(
     val stage = stageIdToStage(task.stageId)
 
     def markStageAsFinished(stage: Stage) = {
-      val serviceTime = stageToInfos(stage).submissionTime match {
+      val serviceTime = stage.info.submissionTime match {
         case Some(t) => "%.03f".format((clock.getTime() - t) / 1000.0)
         case _ => "Unknown"
       }
       logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
-      stageToInfos(stage).completionTime = Some(clock.getTime())
-      listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage)))
+      stage.info.completionTime = Some(clock.getTime())
+      listenerBus.post(SparkListenerStageCompleted(stage.info))
       runningStages -= stage
     }
     event.reason match {
@@ -822,10 +790,10 @@ class DAGScheduler(
           // TODO: fail the stage if the accumulator update fails...
           Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
         }
-        pendingTasks(stage) -= task
+        stage.pendingTasks -= task
         task match {
           case rt: ResultTask[_, _] =>
-            resultStageToJob.get(stage) match {
+            stage.resultOfJob match {
               case Some(job) =>
                 if (!job.finished(rt.outputId)) {
                   job.finished(rt.outputId) = true
@@ -833,7 +801,7 @@ class DAGScheduler(
                   // If the whole job has finished, remove it
                   if (job.numFinished == job.numPartitions) {
                     markStageAsFinished(stage)
-                    cleanupStateForJobAndIndependentStages(job, Some(stage))
+                    cleanupStateForJobAndIndependentStages(job)
                     listenerBus.post(SparkListenerJobEnd(job.jobId, JobSucceeded))
                   }
 
@@ -860,7 +828,7 @@ class DAGScheduler(
             } else {
               stage.addOutputLoc(smt.partitionId, status)
             }
-            if (runningStages.contains(stage) && pendingTasks(stage).isEmpty) {
+            if (runningStages.contains(stage) && stage.pendingTasks.isEmpty) {
               markStageAsFinished(stage)
               logInfo("looking for newly runnable stages")
               logInfo("running: " + runningStages)
@@ -909,7 +877,7 @@ class DAGScheduler(
 
       case Resubmitted =>
         logInfo("Resubmitted " + task + ", so marking it as still running")
-        pendingTasks(stage) += task
+        stage.pendingTasks += task
 
       case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
         // Mark the stage that the reducer was in as unrunnable
@@ -994,13 +962,14 @@ class DAGScheduler(
   }
 
   private[scheduler] def handleStageCancellation(stageId: Int) {
-    if (stageIdToJobIds.contains(stageId)) {
-      val jobsThatUseStage: Array[Int] = stageIdToJobIds(stageId).toArray
-      jobsThatUseStage.foreach(jobId => {
-        handleJobCancellation(jobId, "because Stage %s was cancelled".format(stageId))
-      })
-    } else {
-      logInfo("No active jobs to kill for Stage " + stageId)
+    stageIdToStage.get(stageId) match {
+      case Some(stage) =>
+        val jobsThatUseStage: Array[Int] = stage.jobIds.toArray
+        jobsThatUseStage.foreach { jobId =>
+          handleJobCancellation(jobId, s"because Stage $stageId was cancelled")
+        }
+      case None =>
+        logInfo("No active jobs to kill for Stage " + stageId)
     }
     submitWaitingStages()
   }
@@ -1009,8 +978,8 @@ class DAGScheduler(
     if (!jobIdToStageIds.contains(jobId)) {
       logDebug("Trying to cancel unregistered job " + jobId)
     } else {
-      failJobAndIndependentStages(jobIdToActiveJob(jobId),
-        "Job %d cancelled %s".format(jobId, reason), None)
+      failJobAndIndependentStages(
+        jobIdToActiveJob(jobId), "Job %d cancelled %s".format(jobId, reason))
     }
     submitWaitingStages()
   }
@@ -1024,26 +993,21 @@ class DAGScheduler(
       // Skip all the actions if the stage has been removed.
       return
     }
-    val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
-    stageToInfos(failedStage).completionTime = Some(clock.getTime())
-    for (resultStage <- dependentStages) {
-      val job = resultStageToJob(resultStage)
-      failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason",
-        Some(resultStage))
+    val dependentJobs: Seq[ActiveJob] =
+      activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq
+    failedStage.info.completionTime = Some(clock.getTime())
+    for (job <- dependentJobs) {
+      failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason")
     }
-    if (dependentStages.isEmpty) {
+    if (dependentJobs.isEmpty) {
       logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
     }
   }
 
   /**
    * Fails a job and all stages that are only used by that job, and cleans up relevant state.
-   *
-   * @param resultStage The result stage for the job, if known. Used to cleanup state for the job
-   *                    slightly more efficiently than when not specified.
    */
-  private def failJobAndIndependentStages(job: ActiveJob, failureReason: String,
-      resultStage: Option[Stage]) {
+  private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) {
     val error = new SparkException(failureReason)
     var ableToCancelStages = true
 
@@ -1057,7 +1021,7 @@ class DAGScheduler(
       logError("No stages registered for job " + job.jobId)
     }
     stages.foreach { stageId =>
-      val jobsForStage = stageIdToJobIds.get(stageId)
+      val jobsForStage: Option[HashSet[Int]] = stageIdToStage.get(stageId).map(_.jobIds)
       if (jobsForStage.isEmpty || !jobsForStage.get.contains(job.jobId)) {
         logError(
           "Job %d not registered for stage %d even though that stage was registered for the job"
@@ -1071,9 +1035,8 @@ class DAGScheduler(
           if (runningStages.contains(stage)) {
             try { // cancelTasks will fail if a SchedulerBackend does not implement killTask
               taskScheduler.cancelTasks(stageId, shouldInterruptThread)
-              val stageInfo = stageToInfos(stage)
-              stageInfo.stageFailed(failureReason)
-              listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage)))
+              stage.info.stageFailed(failureReason)
+              listenerBus.post(SparkListenerStageCompleted(stage.info))
             } catch {
               case e: UnsupportedOperationException =>
                 logInfo(s"Could not cancel tasks for stage $stageId", e)
@@ -1086,7 +1049,7 @@ class DAGScheduler(
 
     if (ableToCancelStages) {
       job.listener.jobFailed(error)
-      cleanupStateForJobAndIndependentStages(job, resultStage)
+      cleanupStateForJobAndIndependentStages(job)
       listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/9d8666ca/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 798cbc5..8009054 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.scheduler
 
+import scala.collection.mutable.HashSet
+
 import org.apache.spark._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.BlockManagerId
@@ -56,8 +58,22 @@ private[spark] class Stage(
   val numPartitions = rdd.partitions.size
   val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
   var numAvailableOutputs = 0
+
+  /** Set of jobs that this stage belongs to. */
+  val jobIds = new HashSet[Int]
+
+  /** For the final stage (consisting only of ResultTasks), a link to its ActiveJob. */
+  var resultOfJob: Option[ActiveJob] = None
+  var pendingTasks = new HashSet[Task[_]]
+
   private var nextAttemptId = 0
 
+  val name = callSite.shortForm
+  val details = callSite.longForm
+
+  /** Pointer to the [[StageInfo]] for this stage, initialized from this Stage at construction. */
+  var info: StageInfo = StageInfo.fromStage(this)
+
   def isAvailable: Boolean = {
     if (!isShuffleMap) {
       true
@@ -108,9 +124,6 @@ private[spark] class Stage(
 
   def attemptId: Int = nextAttemptId
 
-  val name = callSite.shortForm
-  val details = callSite.longForm
-
   override def toString = "Stage " + id
 
   override def hashCode(): Int = id

http://git-wip-us.apache.org/repos/asf/spark/blob/9d8666ca/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 44dd1e0..9021662 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -686,15 +686,11 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
     BlockManagerId("exec-" + host, host, 12345, 0)
 
   private def assertDataStructuresEmpty = {
-    assert(scheduler.pendingTasks.isEmpty)
     assert(scheduler.activeJobs.isEmpty)
     assert(scheduler.failedStages.isEmpty)
     assert(scheduler.jobIdToActiveJob.isEmpty)
     assert(scheduler.jobIdToStageIds.isEmpty)
-    assert(scheduler.stageIdToJobIds.isEmpty)
     assert(scheduler.stageIdToStage.isEmpty)
-    assert(scheduler.stageToInfos.isEmpty)
-    assert(scheduler.resultStageToJob.isEmpty)
     assert(scheduler.runningStages.isEmpty)
     assert(scheduler.shuffleToMapStage.isEmpty)
     assert(scheduler.waitingStages.isEmpty)