Posted to commits@spark.apache.org by pw...@apache.org on 2014/04/23 04:24:07 UTC

git commit: [Spark-1538] Fix SparkUI incorrectly hiding persisted RDDs

Repository: spark
Updated Branches:
  refs/heads/master 995fdc96b -> 2de573877


[Spark-1538] Fix SparkUI incorrectly hiding persisted RDDs

**Bug**: After the following command `sc.parallelize(1 to 1000).persist.map(_ + 1).count()` is run, the persisted RDD is missing from the Storage tab of the SparkUI.

**Cause**: The command creates two RDDs in one stage, a `ParallelCollectionRDD` and a `MappedRDD`. However, the existing StageInfo keeps only the RDDInfo of the last RDD associated with the stage (`MappedRDD`), so all information about the first RDD (`ParallelCollectionRDD`) is discarded. In this case we persist the first RDD, but the StorageTab doesn't know about it because it is not encoded in the StageInfo.

**Fix**: Record information of all RDDs in StageInfo, instead of just the last RDD (i.e. `stage.rdd`). Since stage boundaries are marked by shuffle dependencies, the solution is to traverse the last RDD's dependency tree, visiting only ancestor RDDs related through a sequence of narrow dependencies.
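
To make the stage boundary concrete, here is the lineage from the bug report annotated with the RDD classes involved (a sketch for illustration; the class names are the internal ones this version of Spark creates):

    // Both RDDs land in the same (single) stage, since map() introduces only a narrow dependency.
    val parallel = sc.parallelize(1 to 1000).persist()   // ParallelCollectionRDD, persisted
    val mapped   = parallel.map(_ + 1)                   // MappedRDD, the last RDD of the stage
    mapped.count()
    // Previously the StageInfo recorded only `mapped`; with this change it also records `parallel`,
    // found by walking the narrow dependency from MappedRDD back to ParallelCollectionRDD.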

---

This PR also moves RDDInfo to its own file, includes a few style fixes, and adds a unit test for constructing StageInfos.

Author: Andrew Or <an...@gmail.com>

Closes #469 from andrewor14/storage-ui-fix and squashes the following commits:

07fc7f0 [Andrew Or] Add back comment that was accidentally removed (minor)
5d799fe [Andrew Or] Add comment to justify testing of getNarrowAncestors with cycles
9d0e2b8 [Andrew Or] Hide details of getNarrowAncestors from outsiders
d2bac8a [Andrew Or] Deal with cycles in RDD dependency graph + add extensive tests
2acb177 [Andrew Or] Move getNarrowAncestors to RDD.scala
bfe83f0 [Andrew Or] Backtrace RDD dependency tree to find all RDDs that belong to a Stage


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2de57387
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2de57387
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2de57387

Branch: refs/heads/master
Commit: 2de573877fbed20092f1b3af20b603b30ba9a940
Parents: 995fdc9
Author: Andrew Or <an...@gmail.com>
Authored: Tue Apr 22 19:24:03 2014 -0700
Committer: Patrick Wendell <pw...@gmail.com>
Committed: Tue Apr 22 19:24:03 2014 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/TaskContext.scala    |   2 +-
 .../org/apache/spark/executor/TaskMetrics.scala |   2 +-
 .../main/scala/org/apache/spark/rdd/RDD.scala   |  25 ++++
 .../org/apache/spark/scheduler/JobLogger.scala  |   2 +-
 .../org/apache/spark/scheduler/StageInfo.scala  |  19 ++-
 .../org/apache/spark/storage/RDDInfo.scala      |  55 +++++++
 .../org/apache/spark/storage/StorageUtils.scala |  44 +-----
 .../apache/spark/ui/exec/ExecutorsPage.scala    |   6 +-
 .../apache/spark/ui/storage/StorageTab.scala    |   4 +-
 .../org/apache/spark/util/JsonProtocol.scala    |  22 +--
 .../org/apache/spark/CacheManagerSuite.scala    |   6 +-
 .../scala/org/apache/spark/PipedRDDSuite.scala  |   2 +-
 .../scala/org/apache/spark/rdd/RDDSuite.scala   | 148 ++++++++++++++++++-
 .../spark/scheduler/SparkListenerSuite.scala    |  54 ++++++-
 .../apache/spark/util/JsonProtocolSuite.scala   |   8 +-
 15 files changed, 318 insertions(+), 81 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2de57387/core/src/main/scala/org/apache/spark/TaskContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index dc5a19e..dc012cc 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -33,7 +33,7 @@ class TaskContext(
   val attemptId: Long,
   val runningLocally: Boolean = false,
   @volatile var interrupted: Boolean = false,
-  private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty()
+  private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty
 ) extends Serializable {
 
   @deprecated("use partitionId", "0.8.1")

http://git-wip-us.apache.org/repos/asf/spark/blob/2de57387/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index e4f02a4..350fd74 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -84,7 +84,7 @@ class TaskMetrics extends Serializable {
 }
 
 private[spark] object TaskMetrics {
-  def empty(): TaskMetrics = new TaskMetrics
+  def empty: TaskMetrics = new TaskMetrics
 }
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2de57387/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 5d2ed2b..596dcb8 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -20,6 +20,7 @@ package org.apache.spark.rdd
 import java.util.Random
 
 import scala.collection.Map
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.{classTag, ClassTag}
 
@@ -230,6 +231,30 @@ abstract class RDD[T: ClassTag](
   }
 
   /**
+   * Return the ancestors of the given RDD that are related to it only through a sequence of
+   * narrow dependencies. This traverses the given RDD's dependency tree using DFS, but maintains
+   * no ordering on the RDDs returned.
+   */
+  private[spark] def getNarrowAncestors: Seq[RDD[_]] = {
+    val ancestors = new mutable.HashSet[RDD[_]]
+
+    def visit(rdd: RDD[_]) {
+      val narrowDependencies = rdd.dependencies.filter(_.isInstanceOf[NarrowDependency[_]])
+      val narrowParents = narrowDependencies.map(_.rdd)
+      val narrowParentsNotVisited = narrowParents.filterNot(ancestors.contains)
+      narrowParentsNotVisited.foreach { parent =>
+        ancestors.add(parent)
+        visit(parent)
+      }
+    }
+
+    visit(this)
+
+    // In case there is a cycle, do not include the root itself
+    ancestors.filterNot(_ == this).toSeq
+  }
+
+  /**
    * Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing.
    */
   private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
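
Since `getNarrowAncestors` is `private[spark]`, only Spark-internal code and tests (see RDDSuite below) can call it. A rough sketch of what it returns for a simple lineage, assuming a Spark-internal call site:

    val rdd1 = sc.parallelize(1 to 100, 4)           // ParallelCollectionRDD
    val rdd2 = rdd1.filter(_ % 2 == 0).map(_ + 1)    // FilteredRDD -> MappedRDD
    rdd2.getNarrowAncestors
    // => a Seq containing rdd1 (ParallelCollectionRDD) and the intermediate FilteredRDD,
    //    in no particular order; ancestors on the far side of a shuffle are not included.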

http://git-wip-us.apache.org/repos/asf/spark/blob/2de57387/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index 713aebf..a1e21ca 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -207,7 +207,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener
   override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
     val taskInfo = taskEnd.taskInfo
     var taskStatus = "TASK_TYPE=%s".format(taskEnd.taskType)
-    val taskMetrics = if (taskEnd.taskMetrics != null) taskEnd.taskMetrics else TaskMetrics.empty()
+    val taskMetrics = if (taskEnd.taskMetrics != null) taskEnd.taskMetrics else TaskMetrics.empty
     taskEnd.reason match {
       case Success => taskStatus += " STATUS=SUCCESS"
         recordTaskMetrics(taskEnd.stageId, taskStatus, taskInfo, taskMetrics)

http://git-wip-us.apache.org/repos/asf/spark/blob/2de57387/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index 9f732f7..b42e231 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -25,7 +25,7 @@ import org.apache.spark.storage.RDDInfo
  * Stores information about a stage to pass from the scheduler to SparkListeners.
  */
 @DeveloperApi
-class StageInfo(val stageId: Int, val name: String, val numTasks: Int, val rddInfo: RDDInfo) {
+class StageInfo(val stageId: Int, val name: String, val numTasks: Int, val rddInfos: Seq[RDDInfo]) {
   /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */
   var submissionTime: Option[Long] = None
   /** Time when all tasks in the stage completed or when the stage was cancelled. */
@@ -41,12 +41,17 @@ class StageInfo(val stageId: Int, val name: String, val numTasks: Int, val rddIn
   }
 }
 
-private[spark]
-object StageInfo {
+private[spark] object StageInfo {
+  /**
+   * Construct a StageInfo from a Stage.
+   *
+   * Each Stage is associated with one or many RDDs, with the boundary of a Stage marked by
+   * shuffle dependencies. Therefore, all ancestor RDDs related to this Stage's RDD through a
+   * sequence of narrow dependencies should also be associated with this Stage.
+   */
   def fromStage(stage: Stage): StageInfo = {
-    val rdd = stage.rdd
-    val rddName = Option(rdd.name).getOrElse(rdd.id.toString)
-    val rddInfo = new RDDInfo(rdd.id, rddName, rdd.partitions.size, rdd.getStorageLevel)
-    new StageInfo(stage.id, stage.name, stage.numTasks, rddInfo)
+    val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd)
+    val rddInfos = Seq(RDDInfo.fromRdd(stage.rdd)) ++ ancestorRddInfos
+    new StageInfo(stage.id, stage.name, stage.numTasks, rddInfos)
   }
 }
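
Because StageInfo is a @DeveloperApi class, the new `rddInfos` field is also visible to user-defined SparkListeners. A minimal sketch of a hypothetical listener (not part of this patch) that logs every RDD of a submitted stage, mirroring what the StorageTab below does:

    import org.apache.spark.scheduler.{SparkListener, SparkListenerStageSubmitted}

    class RddLoggingListener extends SparkListener {
      override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
        val stageInfo = stageSubmitted.stageInfo
        stageInfo.rddInfos.foreach { info =>
          // One line per RDD in the stage: the stage's own RDD plus its narrow ancestors.
          println(s"Stage ${stageInfo.stageId}: RDD ${info.id} (${info.name}), " +
            s"${info.numPartitions} partitions, storage level ${info.storageLevel}")
        }
      }
    }
    // Register with: sc.addSparkListener(new RddLoggingListener)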

http://git-wip-us.apache.org/repos/asf/spark/blob/2de57387/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala
new file mode 100644
index 0000000..023fd6e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
+
+@DeveloperApi
+class RDDInfo(
+    val id: Int,
+    val name: String,
+    val numPartitions: Int,
+    val storageLevel: StorageLevel)
+  extends Ordered[RDDInfo] {
+
+  var numCachedPartitions = 0
+  var memSize = 0L
+  var diskSize = 0L
+  var tachyonSize = 0L
+
+  override def toString = {
+    import Utils.bytesToString
+    ("RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; " +
+      "TachyonSize: %s; DiskSize: %s").format(
+        name, id, storageLevel.toString, numCachedPartitions, numPartitions,
+        bytesToString(memSize), bytesToString(tachyonSize), bytesToString(diskSize))
+  }
+
+  override def compare(that: RDDInfo) = {
+    this.id - that.id
+  }
+}
+
+private[spark] object RDDInfo {
+  def fromRdd(rdd: RDD[_]): RDDInfo = {
+    val rddName = Option(rdd.name).getOrElse(rdd.id.toString)
+    new RDDInfo(rdd.id, rddName, rdd.partitions.size, rdd.getStorageLevel)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/2de57387/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
index 7ed3713..1eddd1c 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
@@ -21,60 +21,30 @@ import scala.collection.Map
 import scala.collection.mutable
 
 import org.apache.spark.SparkContext
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.util.Utils
 
-private[spark]
-class StorageStatus(
+/** Storage information for each BlockManager. */
+private[spark] class StorageStatus(
     val blockManagerId: BlockManagerId,
     val maxMem: Long,
     val blocks: mutable.Map[BlockId, BlockStatus] = mutable.Map.empty) {
 
-  def memUsed() = blocks.values.map(_.memSize).reduceOption(_ + _).getOrElse(0L)
+  def memUsed = blocks.values.map(_.memSize).reduceOption(_ + _).getOrElse(0L)
 
   def memUsedByRDD(rddId: Int) =
     rddBlocks.filterKeys(_.rddId == rddId).values.map(_.memSize).reduceOption(_ + _).getOrElse(0L)
 
-  def diskUsed() = blocks.values.map(_.diskSize).reduceOption(_ + _).getOrElse(0L)
+  def diskUsed = blocks.values.map(_.diskSize).reduceOption(_ + _).getOrElse(0L)
 
   def diskUsedByRDD(rddId: Int) =
     rddBlocks.filterKeys(_.rddId == rddId).values.map(_.diskSize).reduceOption(_ + _).getOrElse(0L)
 
-  def memRemaining : Long = maxMem - memUsed()
+  def memRemaining: Long = maxMem - memUsed
 
   def rddBlocks = blocks.collect { case (rdd: RDDBlockId, status) => (rdd, status) }
 }
 
-@DeveloperApi
-private[spark]
-class RDDInfo(
-    val id: Int,
-    val name: String,
-    val numPartitions: Int,
-    val storageLevel: StorageLevel)
-  extends Ordered[RDDInfo] {
-
-  var numCachedPartitions = 0
-  var memSize = 0L
-  var diskSize = 0L
-  var tachyonSize = 0L
-
-  override def toString = {
-    import Utils.bytesToString
-    ("RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s;" +
-      "TachyonSize: %s; DiskSize: %s").format(
-        name, id, storageLevel.toString, numCachedPartitions, numPartitions,
-        bytesToString(memSize), bytesToString(tachyonSize), bytesToString(diskSize))
-  }
-
-  override def compare(that: RDDInfo) = {
-    this.id - that.id
-  }
-}
-
-/* Helper methods for storage-related objects */
-private[spark]
-object StorageUtils {
+/** Helper methods for storage-related objects. */
+private[spark] object StorageUtils {
 
   /**
    * Returns basic information of all RDDs persisted in the given SparkContext. This does not

http://git-wip-us.apache.org/repos/asf/spark/blob/2de57387/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
index c1e69f6..6cb43c0 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
@@ -32,7 +32,7 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") {
   def render(request: HttpServletRequest): Seq[Node] = {
     val storageStatusList = listener.storageStatusList
     val maxMem = storageStatusList.map(_.maxMem).fold(0L)(_ + _)
-    val memUsed = storageStatusList.map(_.memUsed()).fold(0L)(_ + _)
+    val memUsed = storageStatusList.map(_.memUsed).fold(0L)(_ + _)
     val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)).fold(0L)(_ + _)
     val execInfo = for (statusId <- 0 until storageStatusList.size) yield getExecInfo(statusId)
     val execInfoSorted = execInfo.sortBy(_.getOrElse("Executor ID", ""))
@@ -106,9 +106,9 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") {
     val execId = status.blockManagerId.executorId
     val hostPort = status.blockManagerId.hostPort
     val rddBlocks = status.blocks.size
-    val memUsed = status.memUsed()
+    val memUsed = status.memUsed
     val maxMem = status.maxMem
-    val diskUsed = status.diskUsed()
+    val diskUsed = status.diskUsed
     val activeTasks = listener.executorToTasksActive.getOrElse(execId, 0)
     val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0)
     val completedTasks = listener.executorToTasksComplete.getOrElse(execId, 0)

http://git-wip-us.apache.org/repos/asf/spark/blob/2de57387/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
index c04ef0a..07ec297 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
@@ -66,8 +66,8 @@ private[ui] class StorageListener(storageStatusListener: StorageStatusListener)
   }
 
   override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = synchronized {
-    val rddInfo = stageSubmitted.stageInfo.rddInfo
-    _rddInfoMap.getOrElseUpdate(rddInfo.id, rddInfo)
+    val rddInfos = stageSubmitted.stageInfo.rddInfos
+    rddInfos.foreach { info => _rddInfoMap.getOrElseUpdate(info.id, info) }
   }
 
   override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized {

http://git-wip-us.apache.org/repos/asf/spark/blob/2de57387/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index 465835e..9aed3e0 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -176,7 +176,7 @@ private[spark] object JsonProtocol {
    * -------------------------------------------------------------------- */
 
   def stageInfoToJson(stageInfo: StageInfo): JValue = {
-    val rddInfo = rddInfoToJson(stageInfo.rddInfo)
+    val rddInfo = JArray(stageInfo.rddInfos.map(rddInfoToJson).toList)
     val submissionTime = stageInfo.submissionTime.map(JInt(_)).getOrElse(JNothing)
     val completionTime = stageInfo.completionTime.map(JInt(_)).getOrElse(JNothing)
     val failureReason = stageInfo.failureReason.map(JString(_)).getOrElse(JNothing)
@@ -208,7 +208,8 @@ private[spark] object JsonProtocol {
       taskMetrics.shuffleReadMetrics.map(shuffleReadMetricsToJson).getOrElse(JNothing)
     val shuffleWriteMetrics =
       taskMetrics.shuffleWriteMetrics.map(shuffleWriteMetricsToJson).getOrElse(JNothing)
-    val updatedBlocks = taskMetrics.updatedBlocks.map { blocks =>
+    val updatedBlocks =
+      taskMetrics.updatedBlocks.map { blocks =>
         JArray(blocks.toList.map { case (id, status) =>
           ("Block ID" -> id.toString) ~
           ("Status" -> blockStatusToJson(status))
@@ -467,13 +468,13 @@ private[spark] object JsonProtocol {
     val stageId = (json \ "Stage ID").extract[Int]
     val stageName = (json \ "Stage Name").extract[String]
     val numTasks = (json \ "Number of Tasks").extract[Int]
-    val rddInfo = rddInfoFromJson(json \ "RDD Info")
+    val rddInfos = (json \ "RDD Info").extract[List[JValue]].map(rddInfoFromJson)
     val submissionTime = Utils.jsonOption(json \ "Submission Time").map(_.extract[Long])
     val completionTime = Utils.jsonOption(json \ "Completion Time").map(_.extract[Long])
     val failureReason = Utils.jsonOption(json \ "Failure Reason").map(_.extract[String])
     val emittedTaskSizeWarning = (json \ "Emitted Task Size Warning").extract[Boolean]
 
-    val stageInfo = new StageInfo(stageId, stageName, numTasks, rddInfo)
+    val stageInfo = new StageInfo(stageId, stageName, numTasks, rddInfos)
     stageInfo.submissionTime = submissionTime
     stageInfo.completionTime = completionTime
     stageInfo.failureReason = failureReason
@@ -518,13 +519,14 @@ private[spark] object JsonProtocol {
       Utils.jsonOption(json \ "Shuffle Read Metrics").map(shuffleReadMetricsFromJson)
     metrics.shuffleWriteMetrics =
       Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson)
-    metrics.updatedBlocks = Utils.jsonOption(json \ "Updated Blocks").map { value =>
-      value.extract[List[JValue]].map { block =>
-        val id = BlockId((block \ "Block ID").extract[String])
-        val status = blockStatusFromJson(block \ "Status")
-        (id, status)
+    metrics.updatedBlocks =
+      Utils.jsonOption(json \ "Updated Blocks").map { value =>
+        value.extract[List[JValue]].map { block =>
+          val id = BlockId((block \ "Block ID").extract[String])
+          val status = blockStatusFromJson(block \ "Status")
+          (id, status)
+        }
       }
-    }
     metrics
   }
 

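In the logged JSON, the stage's "RDD Info" field correspondingly changes from a single object to an array, with the stage's own (last) RDD first followed by its narrow ancestors. An abbreviated, illustrative shape (only field names appearing in this diff; values are made up):

    "Stage ID": 1,
    "Stage Name": "count at <console>:15",
    "Number of Tasks": 4,
    "RDD Info": [ { ...the stage's last RDD... }, { ...a narrow ancestor... } ],
    "Submission Time": ...,
    "Completion Time": ...,
    "Emitted Task Size Warning": false
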
http://git-wip-us.apache.org/repos/asf/spark/blob/2de57387/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index b86923f..fd5b090 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -60,7 +60,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
 
     whenExecuting(blockManager) {
       val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
-        taskMetrics = TaskMetrics.empty())
+        taskMetrics = TaskMetrics.empty)
       val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
       assert(value.toList === List(1, 2, 3, 4))
     }
@@ -73,7 +73,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
 
     whenExecuting(blockManager) {
       val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
-        taskMetrics = TaskMetrics.empty())
+        taskMetrics = TaskMetrics.empty)
       val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
       assert(value.toList === List(5, 6, 7))
     }
@@ -87,7 +87,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
 
     whenExecuting(blockManager) {
       val context = new TaskContext(0, 0, 0, runningLocally = true, interrupted = false,
-        taskMetrics = TaskMetrics.empty())
+        taskMetrics = TaskMetrics.empty)
       val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
       assert(value.toList === List(1, 2, 3, 4))
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/2de57387/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
index dfe0575..0bb6a6b 100644
--- a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
@@ -179,7 +179,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext {
       val hadoopPart1 = generateFakeHadoopPartition()
       val pipedRdd = new PipedRDD(nums, "printenv " + varName)
       val tContext = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
-        taskMetrics = TaskMetrics.empty())
+        taskMetrics = TaskMetrics.empty)
       val rddIter = pipedRdd.compute(hadoopPart1, tContext)
       val arr = rddIter.toArray
       assert(arr(0) == "/some/path")

http://git-wip-us.apache.org/repos/asf/spark/blob/2de57387/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 1901330..d7c9034 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -17,12 +17,10 @@
 
 package org.apache.spark.rdd
 
-import scala.collection.mutable.HashMap
-import scala.collection.parallel.mutable
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.reflect.ClassTag
 
 import org.scalatest.FunSuite
-import org.scalatest.concurrent.Timeouts._
-import org.scalatest.time.{Millis, Span}
 
 import org.apache.spark._
 import org.apache.spark.SparkContext._
@@ -153,7 +151,7 @@ class RDDSuite extends FunSuite with SharedSparkContext {
         if (shouldFail) {
           throw new Exception("injected failure")
         } else {
-          return Array(1, 2, 3, 4).iterator
+          Array(1, 2, 3, 4).iterator
         }
       }
     }.cache()
@@ -568,4 +566,144 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     val ids = ranked.map(_._1).distinct().collect()
     assert(ids.length === n)
   }
+
+  test("getNarrowAncestors") {
+    val rdd1 = sc.parallelize(1 to 100, 4)
+    val rdd2 = rdd1.filter(_ % 2 == 0).map(_ + 1)
+    val rdd3 = rdd2.map(_ - 1).filter(_ < 50).map(i => (i, i))
+    val rdd4 = rdd3.reduceByKey(_ + _)
+    val rdd5 = rdd4.mapValues(_ + 1).mapValues(_ + 2).mapValues(_ + 3)
+    val ancestors1 = rdd1.getNarrowAncestors
+    val ancestors2 = rdd2.getNarrowAncestors
+    val ancestors3 = rdd3.getNarrowAncestors
+    val ancestors4 = rdd4.getNarrowAncestors
+    val ancestors5 = rdd5.getNarrowAncestors
+
+    // Simple dependency tree with a single branch
+    assert(ancestors1.size === 0)
+    assert(ancestors2.size === 2)
+    assert(ancestors2.count(_.isInstanceOf[ParallelCollectionRDD[_]]) === 1)
+    assert(ancestors2.count(_.isInstanceOf[FilteredRDD[_]]) === 1)
+    assert(ancestors3.size === 5)
+    assert(ancestors3.count(_.isInstanceOf[ParallelCollectionRDD[_]]) === 1)
+    assert(ancestors3.count(_.isInstanceOf[FilteredRDD[_]]) === 2)
+    assert(ancestors3.count(_.isInstanceOf[MappedRDD[_, _]]) === 2)
+
+    // Any ancestors before the shuffle are not considered
+    assert(ancestors4.size === 1)
+    assert(ancestors4.count(_.isInstanceOf[ShuffledRDD[_, _, _]]) === 1)
+    assert(ancestors5.size === 4)
+    assert(ancestors5.count(_.isInstanceOf[ShuffledRDD[_, _, _]]) === 1)
+    assert(ancestors5.count(_.isInstanceOf[MapPartitionsRDD[_, _]]) === 1)
+    assert(ancestors5.count(_.isInstanceOf[MappedValuesRDD[_, _, _]]) === 2)
+  }
+
+  test("getNarrowAncestors with multiple parents") {
+    val rdd1 = sc.parallelize(1 to 100, 5)
+    val rdd2 = sc.parallelize(1 to 200, 10).map(_ + 1)
+    val rdd3 = sc.parallelize(1 to 300, 15).filter(_ > 50)
+    val rdd4 = rdd1.map(i => (i, i))
+    val rdd5 = rdd2.map(i => (i, i))
+    val rdd6 = sc.union(rdd1, rdd2)
+    val rdd7 = sc.union(rdd1, rdd2, rdd3)
+    val rdd8 = sc.union(rdd6, rdd7)
+    val rdd9 = rdd4.join(rdd5)
+    val ancestors6 = rdd6.getNarrowAncestors
+    val ancestors7 = rdd7.getNarrowAncestors
+    val ancestors8 = rdd8.getNarrowAncestors
+    val ancestors9 = rdd9.getNarrowAncestors
+
+    // Simple dependency tree with multiple branches
+    assert(ancestors6.size === 3)
+    assert(ancestors6.count(_.isInstanceOf[ParallelCollectionRDD[_]]) === 2)
+    assert(ancestors6.count(_.isInstanceOf[MappedRDD[_, _]]) === 1)
+    assert(ancestors7.size === 5)
+    assert(ancestors7.count(_.isInstanceOf[ParallelCollectionRDD[_]]) === 3)
+    assert(ancestors7.count(_.isInstanceOf[MappedRDD[_, _]]) === 1)
+    assert(ancestors7.count(_.isInstanceOf[FilteredRDD[_]]) === 1)
+
+    // Dependency tree with duplicate nodes (e.g. rdd1 should not be reported twice)
+    assert(ancestors8.size === 7)
+    assert(ancestors8.count(_.isInstanceOf[MappedRDD[_, _]]) === 1)
+    assert(ancestors8.count(_.isInstanceOf[FilteredRDD[_]]) === 1)
+    assert(ancestors8.count(_.isInstanceOf[UnionRDD[_]]) === 2)
+    assert(ancestors8.count(_.isInstanceOf[ParallelCollectionRDD[_]]) === 3)
+    assert(ancestors8.count(_ == rdd1) === 1)
+    assert(ancestors8.count(_ == rdd2) === 1)
+    assert(ancestors8.count(_ == rdd3) === 1)
+
+    // Any ancestors before the shuffle are not considered
+    assert(ancestors9.size === 2)
+    assert(ancestors9.count(_.isInstanceOf[CoGroupedRDD[_]]) === 1)
+    assert(ancestors9.count(_.isInstanceOf[MappedValuesRDD[_, _, _]]) === 1)
+  }
+
+  /**
+   * This tests for the pathological condition in which the RDD dependency graph is cyclical.
+   *
+   * Since RDD is part of the public API, applications may actually implement RDDs that allow
+   * such graphs to be constructed. In such cases, getNarrowAncestor should not simply hang.
+   */
+  test("getNarrowAncestors with cycles") {
+    val rdd1 = new CyclicalDependencyRDD[Int]
+    val rdd2 = new CyclicalDependencyRDD[Int]
+    val rdd3 = new CyclicalDependencyRDD[Int]
+    val rdd4 = rdd3.map(_ + 1).filter(_ > 10).map(_ + 2).filter(_ % 5 > 1)
+    val rdd5 = rdd4.map(_ + 2).filter(_ > 20)
+    val rdd6 = sc.union(rdd1, rdd2, rdd3).map(_ + 4).union(rdd5).union(rdd4)
+
+    // Simple cyclical dependency
+    rdd1.addDependency(new OneToOneDependency[Int](rdd2))
+    rdd2.addDependency(new OneToOneDependency[Int](rdd1))
+    val ancestors1 = rdd1.getNarrowAncestors
+    val ancestors2 = rdd2.getNarrowAncestors
+    assert(ancestors1.size === 1)
+    assert(ancestors1.count(_ == rdd2) === 1)
+    assert(ancestors1.count(_ == rdd1) === 0)
+    assert(ancestors2.size === 1)
+    assert(ancestors2.count(_ == rdd1) === 1)
+    assert(ancestors2.count(_ == rdd2) === 0)
+
+    // Cycle involving a longer chain
+    rdd3.addDependency(new OneToOneDependency[Int](rdd4))
+    val ancestors3 = rdd3.getNarrowAncestors
+    val ancestors4 = rdd4.getNarrowAncestors
+    assert(ancestors3.size === 4)
+    assert(ancestors3.count(_.isInstanceOf[MappedRDD[_, _]]) === 2)
+    assert(ancestors3.count(_.isInstanceOf[FilteredRDD[_]]) === 2)
+    assert(ancestors3.count(_ == rdd3) === 0)
+    assert(ancestors4.size === 4)
+    assert(ancestors4.count(_.isInstanceOf[MappedRDD[_, _]]) === 2)
+    assert(ancestors4.count(_.isInstanceOf[FilteredRDD[_]]) === 1)
+    assert(ancestors4.count(_.isInstanceOf[CyclicalDependencyRDD[_]]) === 1)
+    assert(ancestors4.count(_ == rdd3) === 1)
+    assert(ancestors4.count(_ == rdd4) === 0)
+
+    // Cycles that do not involve the root
+    val ancestors5 = rdd5.getNarrowAncestors
+    assert(ancestors5.size === 6)
+    assert(ancestors5.count(_.isInstanceOf[MappedRDD[_, _]]) === 3)
+    assert(ancestors5.count(_.isInstanceOf[FilteredRDD[_]]) === 2)
+    assert(ancestors5.count(_.isInstanceOf[CyclicalDependencyRDD[_]]) === 1)
+    assert(ancestors4.count(_ == rdd3) === 1)
+
+    // Complex cyclical dependency graph (combination of all of the above)
+    val ancestors6 = rdd6.getNarrowAncestors
+    assert(ancestors6.size === 12)
+    assert(ancestors6.count(_.isInstanceOf[UnionRDD[_]]) === 2)
+    assert(ancestors6.count(_.isInstanceOf[MappedRDD[_, _]]) === 4)
+    assert(ancestors6.count(_.isInstanceOf[FilteredRDD[_]]) === 3)
+    assert(ancestors6.count(_.isInstanceOf[CyclicalDependencyRDD[_]]) === 3)
+  }
+
+  /** A contrived RDD that allows the manual addition of dependencies after creation. */
+  private class CyclicalDependencyRDD[T: ClassTag] extends RDD[T](sc, Nil) {
+    private val mutableDependencies: ArrayBuffer[Dependency[_]] = ArrayBuffer.empty
+    override def compute(p: Partition, c: TaskContext): Iterator[T] = Iterator.empty
+    override def getPartitions: Array[Partition] = Array.empty
+    override def getDependencies: Seq[Dependency[_]] = mutableDependencies
+    def addDependency(dep: Dependency[_]) {
+      mutableDependencies += dep
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2de57387/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 36511a9..ab13917 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -133,20 +133,57 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
     val rdd1 = sc.parallelize(1 to 100, 4)
     val rdd2 = rdd1.map(_.toString)
     rdd2.setName("Target RDD")
-    rdd2.count
+    rdd2.count()
 
     assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
 
     listener.stageInfos.size should be {1}
     val (stageInfo, taskInfoMetrics) = listener.stageInfos.head
-    stageInfo.rddInfo.name should be {"Target RDD"}
+    stageInfo.rddInfos.size should be {2}
+    stageInfo.rddInfos.forall(_.numPartitions == 4) should be {true}
+    stageInfo.rddInfos.exists(_.name == "Target RDD") should be {true}
     stageInfo.numTasks should be {4}
-    stageInfo.rddInfo.numPartitions should be {4}
     stageInfo.submissionTime should be ('defined)
     stageInfo.completionTime should be ('defined)
     taskInfoMetrics.length should be {4}
   }
 
+  test("basic creation of StageInfo with shuffle") {
+    val listener = new SaveStageAndTaskInfo
+    sc.addSparkListener(listener)
+    val rdd1 = sc.parallelize(1 to 100, 4)
+    val rdd2 = rdd1.filter(_ % 2 == 0).map(i => (i, i))
+    val rdd3 = rdd2.reduceByKey(_ + _)
+    rdd1.setName("Un")
+    rdd2.setName("Deux")
+    rdd3.setName("Trois")
+
+    rdd1.count()
+    assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+    listener.stageInfos.size should be {1}
+    val stageInfo1 = listener.stageInfos.keys.find(_.stageId == 0).get
+    stageInfo1.rddInfos.size should be {1} // ParallelCollectionRDD
+    stageInfo1.rddInfos.forall(_.numPartitions == 4) should be {true}
+    stageInfo1.rddInfos.exists(_.name == "Un") should be {true}
+    listener.stageInfos.clear()
+
+    rdd2.count()
+    assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+    listener.stageInfos.size should be {1}
+    val stageInfo2 = listener.stageInfos.keys.find(_.stageId == 1).get
+    stageInfo2.rddInfos.size should be {3} // ParallelCollectionRDD, FilteredRDD, MappedRDD
+    stageInfo2.rddInfos.forall(_.numPartitions == 4) should be {true}
+    stageInfo2.rddInfos.exists(_.name == "Deux") should be {true}
+    listener.stageInfos.clear()
+
+    rdd3.count()
+    listener.stageInfos.size should be {2} // Shuffle map stage + result stage
+    val stageInfo3 = listener.stageInfos.keys.find(_.stageId == 2).get
+    stageInfo3.rddInfos.size should be {2} // ShuffledRDD, MapPartitionsRDD
+    stageInfo3.rddInfos.forall(_.numPartitions == 4) should be {true}
+    stageInfo3.rddInfos.exists(_.name == "Trois") should be {true}
+  }
+
   test("StageInfo with fewer tasks than partitions") {
     val listener = new SaveStageAndTaskInfo
     sc.addSparkListener(listener)
@@ -159,7 +196,8 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
     listener.stageInfos.size should be {1}
     val (stageInfo, _) = listener.stageInfos.head
     stageInfo.numTasks should be {2}
-    stageInfo.rddInfo.numPartitions should be {4}
+    stageInfo.rddInfos.size should be {2}
+    stageInfo.rddInfos.forall(_.numPartitions == 4) should be {true}
   }
 
   test("local metrics") {
@@ -167,7 +205,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
     sc.addSparkListener(listener)
     sc.addSparkListener(new StatsReportListener)
     // just to make sure some of the tasks take a noticeable amount of time
-    val w = {i:Int =>
+    val w = { i: Int =>
       if (i == 0)
         Thread.sleep(100)
       i
@@ -199,7 +237,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
       checkNonZeroAvg(
         taskInfoMetrics.map(_._2.executorDeserializeTime),
         stageInfo + " executorDeserializeTime")
-      if (stageInfo.rddInfo.name == d4.name) {
+      if (stageInfo.rddInfos.exists(_.name == d4.name)) {
         checkNonZeroAvg(
           taskInfoMetrics.map(_._2.shuffleReadMetrics.get.fetchWaitTime),
           stageInfo + " fetchWaitTime")
@@ -207,11 +245,11 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
 
       taskInfoMetrics.foreach { case (taskInfo, taskMetrics) =>
         taskMetrics.resultSize should be > (0l)
-        if (stageInfo.rddInfo.name == d2.name || stageInfo.rddInfo.name == d3.name) {
+        if (stageInfo.rddInfos.exists(info => info.name == d2.name || info.name == d3.name)) {
           taskMetrics.shuffleWriteMetrics should be ('defined)
           taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0l)
         }
-        if (stageInfo.rddInfo.name == d4.name) {
+        if (stageInfo.rddInfos.exists(_.name == d4.name)) {
           taskMetrics.shuffleReadMetrics should be ('defined)
           val sm = taskMetrics.shuffleReadMetrics.get
           sm.totalBlocksFetched should be > (0)

http://git-wip-us.apache.org/repos/asf/spark/blob/2de57387/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index 16470bb..3031015 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -231,7 +231,10 @@ class JsonProtocolSuite extends FunSuite {
     assert(info1.submissionTime === info2.submissionTime)
     assert(info1.completionTime === info2.completionTime)
     assert(info1.emittedTaskSizeWarning === info2.emittedTaskSizeWarning)
-    assertEquals(info1.rddInfo, info2.rddInfo)
+    assert(info1.rddInfos.size === info2.rddInfos.size)
+    (0 until info1.rddInfos.size).foreach { i =>
+      assertEquals(info1.rddInfos(i), info2.rddInfos(i))
+    }
   }
 
   private def assertEquals(info1: RDDInfo, info2: RDDInfo) {
@@ -434,7 +437,8 @@ class JsonProtocolSuite extends FunSuite {
   }
 
   private def makeStageInfo(a: Int, b: Int, c: Int, d: Long, e: Long) = {
-    new StageInfo(a, "greetings", b, makeRddInfo(a, b, c, d, e))
+    val rddInfos = (1 to a % 5).map { i => makeRddInfo(a % i, b % i, c % i, d % i, e % i) }
+    new StageInfo(a, "greetings", b, rddInfos)
   }
 
   private def makeTaskInfo(a: Long, b: Int, c: Long) = {