You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2020/09/14 12:13:20 UTC

[GitHub] [spark] Ngone51 commented on a change in pull request #29747: [SPARK-31848][CORE][TEST] DAGSchedulerSuite: Break down the very huge test file

Ngone51 commented on a change in pull request #29747:
URL: https://github.com/apache/spark/pull/29747#discussion_r487854397



##########
File path: core/src/test/scala/org/apache/spark/scheduler/IndeterminateStageSuite.scala
##########
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+
+import scala.collection.mutable.Map
+
+import org.apache.spark._
+import org.apache.spark.internal.config
+import org.apache.spark.scheduler.DAGSchedulerTestHelper.makeBlockManagerId
+
+class IndeterminateStageSuite extends DAGSchedulerTestHelper {
+
+  test("SPARK-25341: abort stage while using old fetch protocol") {
+    conf.set(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL.key, "true")
+    // Construct the scenario of indeterminate stage fetch failed.
+    constructIndeterminateStageFetchFailed()
+    // The job should fail because Spark can't rollback the shuffle map stage while
+    // using old protocol.
+    assert(failure != null && failure.getMessage.contains(
+      "Spark can only do this while using the new shuffle block fetching protocol"))
+  }
+
+  test("SPARK-25341: continuous indeterminate stage roll back") {
+    // shuffleMapRdd1/2/3 are all indeterminate.
+    val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2))
+    val shuffleId1 = shuffleDep1.shuffleId
+
+    val shuffleMapRdd2 = new MyRDD(
+      sc, 2, List(shuffleDep1), tracker = mapOutputTracker, indeterminate = true)
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(2))
+    val shuffleId2 = shuffleDep2.shuffleId
+
+    val shuffleMapRdd3 = new MyRDD(
+      sc, 2, List(shuffleDep2), tracker = mapOutputTracker, indeterminate = true)
+    val shuffleDep3 = new ShuffleDependency(shuffleMapRdd3, new HashPartitioner(2))
+    val shuffleId3 = shuffleDep3.shuffleId
+    val finalRdd = new MyRDD(sc, 2, List(shuffleDep3), tracker = mapOutputTracker)
+
+    submit(finalRdd, Array(0, 1), properties = new Properties())
+
+    // Finish the first 2 shuffle map stages.
+    completeShuffleMapStageSuccessfully(0, 0, 2)
+    assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty))
+
+    completeShuffleMapStageSuccessfully(1, 0, 2, Seq("hostB", "hostD"))
+    assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty))
+
+    // Executor lost on hostB, both stage 0 and stage 1 should be rerun.
+    runEvent(makeCompletionEvent(
+      taskSets(2).tasks(0),
+      FetchFailed(makeBlockManagerId("hostB"), shuffleId2, 0L, 0, 0, "ignored"),
+      null))
+    mapOutputTracker.removeOutputsOnHost("hostB")
+
+    assert(scheduler.failedStages.toSeq.map(_.id) == Seq(1, 2))
+    scheduler.resubmitFailedStages()
+
+    def checkAndCompleteRetryStage(
+                                    taskSetIndex: Int,
+                                    stageId: Int,
+                                    shuffleId: Int): Unit = {

Review comment:
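       nit: use 4-space indentation for the parameters of this multi-line method declaration instead of aligning them with the opening parenthesis.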

##########
File path: core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerTestHelper.scala
##########
@@ -0,0 +1,614 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+import java.util.concurrent.atomic.AtomicBoolean
+
+import scala.annotation.meta.param
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
+import scala.util.control.NonFatal
+
+import org.mockito.Mockito.spy
+import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
+
+import org.apache.spark._
+import org.apache.spark.broadcast.BroadcastManager
+import org.apache.spark.executor.ExecutorMetrics
+import org.apache.spark.rdd.{DeterministicLevel, RDD}
+import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
+import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
+import org.apache.spark.util.{AccumulatorV2, CallSite}
+
+class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler)
+  extends DAGSchedulerEventProcessLoop(dagScheduler) {
+
+  override def post(event: DAGSchedulerEvent): Unit = {
+    try {
+      // Forward event to `onReceive` directly to avoid processing event asynchronously.
+      onReceive(event)
+    } catch {
+      case NonFatal(e) => onError(e)
+    }
+  }
+
+  override def onError(e: Throwable): Unit = {
+    logError("Error in DAGSchedulerEventLoop: ", e)
+    dagScheduler.stop()
+    throw e
+  }
+
+}
+
+class MyCheckpointRDD(
+    sc: SparkContext,
+    numPartitions: Int,
+    dependencies: List[Dependency[_]],
+    locations: Seq[Seq[String]] = Nil,
+    @(transient @param) tracker: MapOutputTrackerMaster = null,
+    indeterminate: Boolean = false)
+  extends MyRDD(sc, numPartitions, dependencies, locations, tracker, indeterminate) {
+
+  // Allow doCheckpoint() on this RDD.
+  override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+    Iterator.empty
+}
+
+/**
+ * An RDD for passing to DAGScheduler. These RDDs will use the dependencies and
+ * preferredLocations (if any) that are passed to them. They are deliberately not executable
+ * so we can test that DAGScheduler does not try to execute RDDs locally.
+ *
+ * Optionally, one can pass in a list of locations to use as preferred locations for each task,
+ * and a MapOutputTrackerMaster to enable reduce task locality. We pass the tracker separately
+ * because, in this test suite, it won't be the same as sc.env.mapOutputTracker.
+ */
+class MyRDD(
+    sc: SparkContext,
+    numPartitions: Int,
+    dependencies: List[Dependency[_]],
+    locations: Seq[Seq[String]] = Nil,
+    @(transient @param) tracker: MapOutputTrackerMaster = null,
+    indeterminate: Boolean = false)
+  extends RDD[(Int, Int)](sc, dependencies) with Serializable {
+
+  override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+    throw new RuntimeException("should not be reached")
+
+  override def getPartitions: Array[Partition] = (0 until numPartitions).map(i => new Partition {
+    override def index: Int = i
+  }).toArray
+
+  override protected def getOutputDeterministicLevel = {
+    if (indeterminate) DeterministicLevel.INDETERMINATE else super.getOutputDeterministicLevel
+  }
+
+  override def getPreferredLocations(partition: Partition): Seq[String] = {
+    if (locations.isDefinedAt(partition.index)) {
+      locations(partition.index)
+    } else if (tracker != null && dependencies.size == 1 &&
+        dependencies(0).isInstanceOf[ShuffleDependency[_, _, _]]) {
+      // If we have only one shuffle dependency, use the same code path as ShuffledRDD for locality
+      val dep = dependencies(0).asInstanceOf[ShuffleDependency[_, _, _]]
+      tracker.getPreferredLocationsForShuffle(dep, partition.index)
+    } else {
+      Nil
+    }
+  }
+
+  override def toString: String = "DAGSchedulerSuiteRDD " + id
+}
+
+class DAGSchedulerSuiteDummyException extends Exception
+
+class DAGSchedulerTestHelper extends SparkFunSuite with TempLocalSparkContext with TimeLimits {
+
+  import DAGSchedulerTestHelper._
+
+  // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x
+  implicit val defaultSignaler: Signaler = ThreadSignaler
+
+  private var firstInit: Boolean = _
+  /** Set of TaskSets the DAGScheduler has requested executed. */
+  val taskSets = scala.collection.mutable.Buffer[TaskSet]()
+
+  /** Stages for which the DAGScheduler has called TaskScheduler.cancelTasks(). */
+  val cancelledStages = new HashSet[Int]()
+
+  val tasksMarkedAsCompleted = new ArrayBuffer[Task[_]]()
+
+  val taskScheduler = new TaskScheduler() {
+    override def schedulingMode: SchedulingMode = SchedulingMode.FIFO
+    override def rootPool: Pool = new Pool("", schedulingMode, 0, 0)
+    override def start() = {}
+    override def stop() = {}
+    override def executorHeartbeatReceived(
+        execId: String,
+        accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])],
+        blockManagerId: BlockManagerId,
+        executorUpdates: Map[(Int, Int), ExecutorMetrics]): Boolean = true
+    override def submitTasks(taskSet: TaskSet) = {
+      // normally done by TaskSetManager
+      taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
+      taskSets += taskSet
+    }
+    override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = {
+      cancelledStages += stageId
+    }
+    override def killTaskAttempt(
+      taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
+    override def killAllTaskAttempts(
+      stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
+    override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
+      taskSets.filter(_.stageId == stageId).lastOption.foreach { ts =>
+        val tasks = ts.tasks.filter(_.partitionId == partitionId)
+        assert(tasks.length == 1)
+        tasksMarkedAsCompleted += tasks.head
+      }
+    }
+    override def setDAGScheduler(dagScheduler: DAGScheduler) = {}
+    override def defaultParallelism() = 2
+    override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
+    override def workerRemoved(workerId: String, host: String, message: String): Unit = {}
+    override def applicationAttemptId(): Option[String] = None
+    override def executorDecommission(
+      executorId: String,
+      decommissionInfo: ExecutorDecommissionInfo): Unit = {}
+    override def getExecutorDecommissionState(
+      executorId: String): Option[ExecutorDecommissionState] = None
+  }
+
+  /**
+   * Listeners which records some information to verify in UTs. Getter-kind methods in this class
+   * ensures the value is returned after ensuring there's no event to process, as well as the
+   * value is immutable: prevent showing odd result by race condition.
+   */
+  class EventInfoRecordingListener extends SparkListener {
+    private val _submittedStageInfos = new HashSet[StageInfo]
+    private val _successfulStages = new HashSet[Int]
+    private val _failedStages = new ArrayBuffer[Int]
+    private val _stageByOrderOfExecution = new ArrayBuffer[Int]
+    private val _endedTasks = new HashSet[Long]
+
+    override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
+      _submittedStageInfos += stageSubmitted.stageInfo
+    }
+
+    override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
+      val stageInfo = stageCompleted.stageInfo
+      _stageByOrderOfExecution += stageInfo.stageId
+      if (stageInfo.failureReason.isEmpty) {
+        _successfulStages += stageInfo.stageId
+      } else {
+        _failedStages += stageInfo.stageId
+      }
+    }
+
+    override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+      _endedTasks += taskEnd.taskInfo.taskId
+    }
+
+    def submittedStageInfos: Set[StageInfo] = {
+      waitForListeners()
+      _submittedStageInfos.toSet
+    }
+
+    def successfulStages: Set[Int] = {
+      waitForListeners()
+      _successfulStages.toSet
+    }
+
+    def failedStages: List[Int] = {
+      waitForListeners()
+      _failedStages.toList
+    }
+
+    def stageByOrderOfExecution: List[Int] = {
+      waitForListeners()
+      _stageByOrderOfExecution.toList
+    }
+
+    def endedTasks: Set[Long] = {
+      waitForListeners()
+      _endedTasks.toSet
+    }
+
+    private def waitForListeners(): Unit = sc.listenerBus.waitUntilEmpty()
+  }
+
+  var sparkListener: EventInfoRecordingListener = null
+
+  var blockManagerMaster: BlockManagerMaster = null
+  var mapOutputTracker: MapOutputTrackerMaster = null
+  var broadcastManager: BroadcastManager = null
+  var securityMgr: SecurityManager = null
+  var scheduler: DAGScheduler = null
+  var dagEventProcessLoopTester: DAGSchedulerEventProcessLoop = null
+
+  /**
+   * Set of cache locations to return from our mock BlockManagerMaster.
+   * Keys are (rdd ID, partition ID). Anything not present will return an empty
+   * list of cache locations silently.
+   */
+  val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
+  // stub out BlockManagerMaster.getLocations to use our cacheLocations
+  class MyBlockManagerMaster(conf: SparkConf) extends BlockManagerMaster(null, null, conf, true) {
+    override def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = {
+      blockIds.map {
+        _.asRDDId.map { id => (id.rddId -> id.splitIndex)
+        }.flatMap { key => cacheLocations.get(key)
+        }.getOrElse(Seq())
+      }.toIndexedSeq
+    }
+    override def removeExecutor(execId: String): Unit = {
+      // don't need to propagate to the driver, which we don't have
+    }
+  }
+
+  /** The list of results that DAGScheduler has collected. */
+  val results = new HashMap[Int, Any]()
+  var failure: Exception = _
+  val jobListener = new JobListener() {
+    override def taskSucceeded(index: Int, result: Any) = results.put(index, result)
+    override def jobFailed(exception: Exception) = { failure = exception }
+  }
+
+  /** A simple helper class for creating custom JobListeners */
+  class SimpleListener extends JobListener {

Review comment:
       `SimpleListener` is only used by `DAGSchedulerSuite`. Shall we keep it in `DAGSchedulerSuite`?

##########
File path: core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerTestHelper.scala
##########
@@ -0,0 +1,614 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+import java.util.concurrent.atomic.AtomicBoolean
+
+import scala.annotation.meta.param
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
+import scala.util.control.NonFatal
+
+import org.mockito.Mockito.spy
+import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
+
+import org.apache.spark._
+import org.apache.spark.broadcast.BroadcastManager
+import org.apache.spark.executor.ExecutorMetrics
+import org.apache.spark.rdd.{DeterministicLevel, RDD}
+import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
+import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
+import org.apache.spark.util.{AccumulatorV2, CallSite}
+
+class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler)
+  extends DAGSchedulerEventProcessLoop(dagScheduler) {
+
+  override def post(event: DAGSchedulerEvent): Unit = {
+    try {
+      // Forward event to `onReceive` directly to avoid processing event asynchronously.
+      onReceive(event)
+    } catch {
+      case NonFatal(e) => onError(e)
+    }
+  }
+
+  override def onError(e: Throwable): Unit = {
+    logError("Error in DAGSchedulerEventLoop: ", e)
+    dagScheduler.stop()
+    throw e
+  }
+
+}
+
+class MyCheckpointRDD(
+    sc: SparkContext,
+    numPartitions: Int,
+    dependencies: List[Dependency[_]],
+    locations: Seq[Seq[String]] = Nil,
+    @(transient @param) tracker: MapOutputTrackerMaster = null,
+    indeterminate: Boolean = false)
+  extends MyRDD(sc, numPartitions, dependencies, locations, tracker, indeterminate) {
+
+  // Allow doCheckpoint() on this RDD.
+  override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+    Iterator.empty
+}
+
+/**
+ * An RDD for passing to DAGScheduler. These RDDs will use the dependencies and
+ * preferredLocations (if any) that are passed to them. They are deliberately not executable
+ * so we can test that DAGScheduler does not try to execute RDDs locally.
+ *
+ * Optionally, one can pass in a list of locations to use as preferred locations for each task,
+ * and a MapOutputTrackerMaster to enable reduce task locality. We pass the tracker separately
+ * because, in this test suite, it won't be the same as sc.env.mapOutputTracker.
+ */
+class MyRDD(
+    sc: SparkContext,
+    numPartitions: Int,
+    dependencies: List[Dependency[_]],
+    locations: Seq[Seq[String]] = Nil,
+    @(transient @param) tracker: MapOutputTrackerMaster = null,
+    indeterminate: Boolean = false)
+  extends RDD[(Int, Int)](sc, dependencies) with Serializable {
+
+  override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+    throw new RuntimeException("should not be reached")
+
+  override def getPartitions: Array[Partition] = (0 until numPartitions).map(i => new Partition {
+    override def index: Int = i
+  }).toArray
+
+  override protected def getOutputDeterministicLevel = {
+    if (indeterminate) DeterministicLevel.INDETERMINATE else super.getOutputDeterministicLevel
+  }
+
+  override def getPreferredLocations(partition: Partition): Seq[String] = {
+    if (locations.isDefinedAt(partition.index)) {
+      locations(partition.index)
+    } else if (tracker != null && dependencies.size == 1 &&
+        dependencies(0).isInstanceOf[ShuffleDependency[_, _, _]]) {
+      // If we have only one shuffle dependency, use the same code path as ShuffledRDD for locality
+      val dep = dependencies(0).asInstanceOf[ShuffleDependency[_, _, _]]
+      tracker.getPreferredLocationsForShuffle(dep, partition.index)
+    } else {
+      Nil
+    }
+  }
+
+  override def toString: String = "DAGSchedulerSuiteRDD " + id
+}
+
+class DAGSchedulerSuiteDummyException extends Exception
+
+class DAGSchedulerTestHelper extends SparkFunSuite with TempLocalSparkContext with TimeLimits {

Review comment:
       How about naming this class `DAGSchedulerBaseSuite` instead?

##########
File path: core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerTestHelper.scala
##########
@@ -0,0 +1,614 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+import java.util.concurrent.atomic.AtomicBoolean
+
+import scala.annotation.meta.param
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
+import scala.util.control.NonFatal
+
+import org.mockito.Mockito.spy
+import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
+
+import org.apache.spark._
+import org.apache.spark.broadcast.BroadcastManager
+import org.apache.spark.executor.ExecutorMetrics
+import org.apache.spark.rdd.{DeterministicLevel, RDD}
+import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
+import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
+import org.apache.spark.util.{AccumulatorV2, CallSite}
+
+class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler)
+  extends DAGSchedulerEventProcessLoop(dagScheduler) {
+
+  override def post(event: DAGSchedulerEvent): Unit = {
+    try {
+      // Forward event to `onReceive` directly to avoid processing event asynchronously.
+      onReceive(event)
+    } catch {
+      case NonFatal(e) => onError(e)
+    }
+  }
+
+  override def onError(e: Throwable): Unit = {
+    logError("Error in DAGSchedulerEventLoop: ", e)
+    dagScheduler.stop()
+    throw e
+  }
+
+}
+
+class MyCheckpointRDD(
+    sc: SparkContext,
+    numPartitions: Int,
+    dependencies: List[Dependency[_]],
+    locations: Seq[Seq[String]] = Nil,
+    @(transient @param) tracker: MapOutputTrackerMaster = null,
+    indeterminate: Boolean = false)
+  extends MyRDD(sc, numPartitions, dependencies, locations, tracker, indeterminate) {
+
+  // Allow doCheckpoint() on this RDD.
+  override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+    Iterator.empty
+}
+
+/**
+ * An RDD for passing to DAGScheduler. These RDDs will use the dependencies and
+ * preferredLocations (if any) that are passed to them. They are deliberately not executable
+ * so we can test that DAGScheduler does not try to execute RDDs locally.
+ *
+ * Optionally, one can pass in a list of locations to use as preferred locations for each task,
+ * and a MapOutputTrackerMaster to enable reduce task locality. We pass the tracker separately
+ * because, in this test suite, it won't be the same as sc.env.mapOutputTracker.
+ */
+class MyRDD(
+    sc: SparkContext,
+    numPartitions: Int,
+    dependencies: List[Dependency[_]],
+    locations: Seq[Seq[String]] = Nil,
+    @(transient @param) tracker: MapOutputTrackerMaster = null,
+    indeterminate: Boolean = false)
+  extends RDD[(Int, Int)](sc, dependencies) with Serializable {
+
+  override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+    throw new RuntimeException("should not be reached")
+
+  override def getPartitions: Array[Partition] = (0 until numPartitions).map(i => new Partition {
+    override def index: Int = i
+  }).toArray
+
+  override protected def getOutputDeterministicLevel = {
+    if (indeterminate) DeterministicLevel.INDETERMINATE else super.getOutputDeterministicLevel
+  }
+
+  override def getPreferredLocations(partition: Partition): Seq[String] = {
+    if (locations.isDefinedAt(partition.index)) {
+      locations(partition.index)
+    } else if (tracker != null && dependencies.size == 1 &&
+        dependencies(0).isInstanceOf[ShuffleDependency[_, _, _]]) {
+      // If we have only one shuffle dependency, use the same code path as ShuffledRDD for locality
+      val dep = dependencies(0).asInstanceOf[ShuffleDependency[_, _, _]]
+      tracker.getPreferredLocationsForShuffle(dep, partition.index)
+    } else {
+      Nil
+    }
+  }
+
+  override def toString: String = "DAGSchedulerSuiteRDD " + id
+}
+
+class DAGSchedulerSuiteDummyException extends Exception
+
+class DAGSchedulerTestHelper extends SparkFunSuite with TempLocalSparkContext with TimeLimits {
+
+  import DAGSchedulerTestHelper._
+
+  // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x
+  implicit val defaultSignaler: Signaler = ThreadSignaler
+
+  private var firstInit: Boolean = _
+  /** Set of TaskSets the DAGScheduler has requested executed. */
+  val taskSets = scala.collection.mutable.Buffer[TaskSet]()
+
+  /** Stages for which the DAGScheduler has called TaskScheduler.cancelTasks(). */
+  val cancelledStages = new HashSet[Int]()
+
+  val tasksMarkedAsCompleted = new ArrayBuffer[Task[_]]()
+
+  val taskScheduler = new TaskScheduler() {
+    override def schedulingMode: SchedulingMode = SchedulingMode.FIFO
+    override def rootPool: Pool = new Pool("", schedulingMode, 0, 0)
+    override def start() = {}
+    override def stop() = {}
+    override def executorHeartbeatReceived(
+        execId: String,
+        accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])],
+        blockManagerId: BlockManagerId,
+        executorUpdates: Map[(Int, Int), ExecutorMetrics]): Boolean = true
+    override def submitTasks(taskSet: TaskSet) = {
+      // normally done by TaskSetManager
+      taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
+      taskSets += taskSet
+    }
+    override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = {
+      cancelledStages += stageId
+    }
+    override def killTaskAttempt(
+      taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
+    override def killAllTaskAttempts(
+      stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
+    override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
+      taskSets.filter(_.stageId == stageId).lastOption.foreach { ts =>
+        val tasks = ts.tasks.filter(_.partitionId == partitionId)
+        assert(tasks.length == 1)
+        tasksMarkedAsCompleted += tasks.head
+      }
+    }
+    override def setDAGScheduler(dagScheduler: DAGScheduler) = {}
+    override def defaultParallelism() = 2
+    override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
+    override def workerRemoved(workerId: String, host: String, message: String): Unit = {}
+    override def applicationAttemptId(): Option[String] = None
+    override def executorDecommission(
+      executorId: String,
+      decommissionInfo: ExecutorDecommissionInfo): Unit = {}
+    override def getExecutorDecommissionState(
+      executorId: String): Option[ExecutorDecommissionState] = None
+  }
+
+  /**
+   * Listeners which records some information to verify in UTs. Getter-kind methods in this class
+   * ensures the value is returned after ensuring there's no event to process, as well as the
+   * value is immutable: prevent showing odd result by race condition.
+   */
+  class EventInfoRecordingListener extends SparkListener {
+    private val _submittedStageInfos = new HashSet[StageInfo]
+    private val _successfulStages = new HashSet[Int]
+    private val _failedStages = new ArrayBuffer[Int]
+    private val _stageByOrderOfExecution = new ArrayBuffer[Int]
+    private val _endedTasks = new HashSet[Long]
+
+    override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
+      _submittedStageInfos += stageSubmitted.stageInfo
+    }
+
+    override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
+      val stageInfo = stageCompleted.stageInfo
+      _stageByOrderOfExecution += stageInfo.stageId
+      if (stageInfo.failureReason.isEmpty) {
+        _successfulStages += stageInfo.stageId
+      } else {
+        _failedStages += stageInfo.stageId
+      }
+    }
+
+    override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+      _endedTasks += taskEnd.taskInfo.taskId
+    }
+
+    def submittedStageInfos: Set[StageInfo] = {
+      waitForListeners()
+      _submittedStageInfos.toSet
+    }
+
+    def successfulStages: Set[Int] = {
+      waitForListeners()
+      _successfulStages.toSet
+    }
+
+    def failedStages: List[Int] = {
+      waitForListeners()
+      _failedStages.toList
+    }
+
+    def stageByOrderOfExecution: List[Int] = {
+      waitForListeners()
+      _stageByOrderOfExecution.toList
+    }
+
+    def endedTasks: Set[Long] = {
+      waitForListeners()
+      _endedTasks.toSet
+    }
+
+    private def waitForListeners(): Unit = sc.listenerBus.waitUntilEmpty()
+  }
+
+  var sparkListener: EventInfoRecordingListener = null
+
+  var blockManagerMaster: BlockManagerMaster = null
+  var mapOutputTracker: MapOutputTrackerMaster = null
+  var broadcastManager: BroadcastManager = null
+  var securityMgr: SecurityManager = null
+  var scheduler: DAGScheduler = null
+  var dagEventProcessLoopTester: DAGSchedulerEventProcessLoop = null
+
+  /**
+   * Set of cache locations to return from our mock BlockManagerMaster.
+   * Keys are (rdd ID, partition ID). Anything not present will return an empty
+   * list of cache locations silently.
+   */
+  val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
+  // stub out BlockManagerMaster.getLocations to use our cacheLocations
+  class MyBlockManagerMaster(conf: SparkConf) extends BlockManagerMaster(null, null, conf, true) {
+    override def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = {
+      blockIds.map {
+        _.asRDDId.map { id => (id.rddId -> id.splitIndex)
+        }.flatMap { key => cacheLocations.get(key)
+        }.getOrElse(Seq())
+      }.toIndexedSeq
+    }
+    override def removeExecutor(execId: String): Unit = {
+      // don't need to propagate to the driver, which we don't have
+    }
+  }
+
+  /** The list of results that DAGScheduler has collected. */
+  val results = new HashMap[Int, Any]()
+  var failure: Exception = _
+  val jobListener = new JobListener() {
+    override def taskSucceeded(index: Int, result: Any) = results.put(index, result)
+    override def jobFailed(exception: Exception) = { failure = exception }
+  }
+
+  /** A simple helper class for creating custom JobListeners */
+  class SimpleListener extends JobListener {
+    val results = new HashMap[Int, Any]
+    var failure: Exception = null
+    override def taskSucceeded(index: Int, result: Any): Unit = results.put(index, result)
+    override def jobFailed(exception: Exception): Unit = { failure = exception }
+  }
+
+  class MyMapOutputTrackerMaster(
+      conf: SparkConf,
+      broadcastManager: BroadcastManager)
+    extends MapOutputTrackerMaster(conf, broadcastManager, true) {
+
+    override def sendTracker(message: Any): Unit = {
+      // no-op, just so we can stop this to avoid leaking threads
+    }
+  }
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    firstInit = true
+  }
+
+  override def sc: SparkContext = {
+    val sc = super.sc
+    if (firstInit) {
+      init(sc)
+      firstInit = false
+    }
+    sc
+  }
+
+  private def init(sc: SparkContext): Unit = {
+    sparkListener = new EventInfoRecordingListener
+    failure = null
+    sc.addSparkListener(sparkListener)
+    taskSets.clear()
+    tasksMarkedAsCompleted.clear()
+    cancelledStages.clear()
+    cacheLocations.clear()
+    results.clear()
+    securityMgr = new SecurityManager(sc.getConf)
+    broadcastManager = new BroadcastManager(true, sc.getConf, securityMgr)
+    mapOutputTracker = spy(new MyMapOutputTrackerMaster(sc.getConf, broadcastManager))
+    blockManagerMaster = spy(new MyBlockManagerMaster(sc.getConf))
+    scheduler = new DAGScheduler(
+      sc,
+      taskScheduler,
+      sc.listenerBus,
+      mapOutputTracker,
+      blockManagerMaster,
+      sc.env)
+    dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler)
+  }
+
+  override def afterEach(): Unit = {
+    try {
+      scheduler.stop()
+      dagEventProcessLoopTester.stop()
+      mapOutputTracker.stop()
+      broadcastManager.stop()
+    } finally {
+      super.afterEach()
+    }
+  }
+
+  override def afterAll(): Unit = {
+    super.afterAll()
+  }
+
+  /**
+   * Type of RDD we use for testing. Note that we should never call the real RDD compute methods.
+   * This is a pair RDD type so it can always be used in ShuffleDependencies.
+   */
+  type PairOfIntsRDD = RDD[(Int, Int)]
+
+  /**
+   * Process the supplied event as if it were the top of the DAGScheduler event queue, expecting
+   * the scheduler not to exit.
+   *
+   * After processing the event, submit waiting stages as is done on most iterations of the
+   * DAGScheduler event loop.
+   */
+  protected def runEvent(event: DAGSchedulerEvent): Unit = {
+    // Ensure the initialization of various components
+    sc
+    dagEventProcessLoopTester.post(event)
+  }
+
+  /**
+   * When we submit dummy Jobs, this is the compute function we supply. Except in a local test
+   * below, we do not expect this function to ever be executed; instead, we will return results
+   * directly through CompletionEvents.
+   */
+  private val jobComputeFunc = (context: TaskContext, it: Iterator[(_)]) =>
+    it.next.asInstanceOf[Tuple2[_, _]]._1
+
+  /** Send the given CompletionEvent messages for the tasks in the TaskSet. */
+  protected def complete(taskSet: TaskSet, taskEndInfos: Seq[(TaskEndReason, Any)]): Unit = {
+    assert(taskSet.tasks.size >= taskEndInfos.size)
+    for ((result, i) <- taskEndInfos.zipWithIndex) {
+      if (i < taskSet.tasks.size) {
+        runEvent(makeCompletionEvent(taskSet.tasks(i), result._1, result._2))
+      }
+    }
+  }
+
+  protected def completeWithAccumulator(
+      accumId: Long,
+      taskSet: TaskSet,
+      results: Seq[(TaskEndReason, Any)]): Unit = {
+    assert(taskSet.tasks.size >= results.size)
+    for ((result, i) <- results.zipWithIndex) {
+      if (i < taskSet.tasks.size) {
+        runEvent(makeCompletionEvent(
+          taskSet.tasks(i),
+          result._1,
+          result._2,
+          Seq(AccumulatorSuite.createLongAccum("", initValue = 1, id = accumId))))
+      }
+    }
+  }
+
+  /** Submits a job to the scheduler and returns the job id. */
+  protected def submit(
+      rdd: RDD[_],
+      partitions: Array[Int],
+      func: (TaskContext, Iterator[_]) => _ = jobComputeFunc,
+      listener: JobListener = jobListener,
+      properties: Properties = null): Int = {
+    val jobId = scheduler.nextJobId.getAndIncrement()
+    runEvent(JobSubmitted(jobId, rdd, func, partitions, CallSite("", ""), listener, properties))
+    jobId
+  }
+
+  /** Submits a map stage to the scheduler and returns the job id. */
+  protected def submitMapStage(
+      shuffleDep: ShuffleDependency[_, _, _],
+      listener: JobListener = jobListener): Int = {
+    val jobId = scheduler.nextJobId.getAndIncrement()
+    runEvent(MapStageSubmitted(jobId, shuffleDep, CallSite("", ""), listener))
+    jobId
+  }
+
+  /** Sends TaskSetFailed to the scheduler. */
+  protected def failed(taskSet: TaskSet, message: String): Unit = {
+    runEvent(TaskSetFailed(taskSet, message, None))
+  }
+
+  /** Sends JobCancelled to the DAG scheduler. */
+  protected def cancel(jobId: Int): Unit = {
+    runEvent(JobCancelled(jobId, None))
+  }
+
+  /** Make some tasks in task set success and check results. */
+  protected def completeAndCheckAnswer(
+      taskSet: TaskSet,
+      taskEndInfos: Seq[(TaskEndReason, Any)],
+      expected: Map[Int, Any]): Unit = {
+    complete(taskSet, taskEndInfos)
+    assert(this.results === expected)
+  }
+
+  // Helper function to validate state when creating tests for task failures
+  private def checkStageId(stageId: Int, attempt: Int, stageAttempt: TaskSet): Unit = {
+    assert(stageAttempt.stageId === stageId)
+    assert(stageAttempt.stageAttemptId == attempt)
+  }
+
+  // Helper functions to extract commonly used code in Fetch Failure test cases
+  protected def setupStageAbortTest(sc: SparkContext): Unit = {

Review comment:
       This is also used only by `DAGSchedulerSuite`. Shall we keep it in `DAGSchedulerSuite`?

##########
File path: core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerTestHelper.scala
##########
@@ -0,0 +1,614 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+import java.util.concurrent.atomic.AtomicBoolean
+
+import scala.annotation.meta.param
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
+import scala.util.control.NonFatal
+
+import org.mockito.Mockito.spy
+import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
+
+import org.apache.spark._
+import org.apache.spark.broadcast.BroadcastManager
+import org.apache.spark.executor.ExecutorMetrics
+import org.apache.spark.rdd.{DeterministicLevel, RDD}
+import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
+import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
+import org.apache.spark.util.{AccumulatorV2, CallSite}
+
+class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler)
+  extends DAGSchedulerEventProcessLoop(dagScheduler) {
+
+  override def post(event: DAGSchedulerEvent): Unit = {
+    try {
+      // Forward event to `onReceive` directly to avoid processing event asynchronously.
+      onReceive(event)
+    } catch {
+      case NonFatal(e) => onError(e)
+    }
+  }
+
+  override def onError(e: Throwable): Unit = {
+    logError("Error in DAGSchedulerEventLoop: ", e)
+    dagScheduler.stop()
+    throw e
+  }
+
+}
+
+class MyCheckpointRDD(
+    sc: SparkContext,
+    numPartitions: Int,
+    dependencies: List[Dependency[_]],
+    locations: Seq[Seq[String]] = Nil,
+    @(transient @param) tracker: MapOutputTrackerMaster = null,
+    indeterminate: Boolean = false)
+  extends MyRDD(sc, numPartitions, dependencies, locations, tracker, indeterminate) {
+
+  // Allow doCheckpoint() on this RDD.
+  override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+    Iterator.empty
+}
+
+/**
+ * An RDD for passing to DAGScheduler. These RDDs will use the dependencies and
+ * preferredLocations (if any) that are passed to them. They are deliberately not executable
+ * so we can test that DAGScheduler does not try to execute RDDs locally.
+ *
+ * Optionally, one can pass in a list of locations to use as preferred locations for each task,
+ * and a MapOutputTrackerMaster to enable reduce task locality. We pass the tracker separately
+ * because, in this test suite, it won't be the same as sc.env.mapOutputTracker.
+ */
+class MyRDD(
+    sc: SparkContext,
+    numPartitions: Int,
+    dependencies: List[Dependency[_]],
+    locations: Seq[Seq[String]] = Nil,
+    @(transient @param) tracker: MapOutputTrackerMaster = null,
+    indeterminate: Boolean = false)
+  extends RDD[(Int, Int)](sc, dependencies) with Serializable {
+
+  override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+    throw new RuntimeException("should not be reached")
+
+  override def getPartitions: Array[Partition] = (0 until numPartitions).map(i => new Partition {
+    override def index: Int = i
+  }).toArray
+
+  override protected def getOutputDeterministicLevel = {
+    if (indeterminate) DeterministicLevel.INDETERMINATE else super.getOutputDeterministicLevel
+  }
+
+  override def getPreferredLocations(partition: Partition): Seq[String] = {
+    if (locations.isDefinedAt(partition.index)) {
+      locations(partition.index)
+    } else if (tracker != null && dependencies.size == 1 &&
+        dependencies(0).isInstanceOf[ShuffleDependency[_, _, _]]) {
+      // If we have only one shuffle dependency, use the same code path as ShuffledRDD for locality
+      val dep = dependencies(0).asInstanceOf[ShuffleDependency[_, _, _]]
+      tracker.getPreferredLocationsForShuffle(dep, partition.index)
+    } else {
+      Nil
+    }
+  }
+
+  override def toString: String = "DAGSchedulerSuiteRDD " + id
+}
+
+class DAGSchedulerSuiteDummyException extends Exception
+
+class DAGSchedulerTestHelper extends SparkFunSuite with TempLocalSparkContext with TimeLimits {
+
+  import DAGSchedulerTestHelper._
+
+  // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x
+  implicit val defaultSignaler: Signaler = ThreadSignaler
+
+  private var firstInit: Boolean = _
+  /** Set of TaskSets the DAGScheduler has requested executed. */
+  val taskSets = scala.collection.mutable.Buffer[TaskSet]()
+
+  /** Stages for which the DAGScheduler has called TaskScheduler.cancelTasks(). */
+  val cancelledStages = new HashSet[Int]()
+
+  val tasksMarkedAsCompleted = new ArrayBuffer[Task[_]]()
+
+  val taskScheduler = new TaskScheduler() {
+    override def schedulingMode: SchedulingMode = SchedulingMode.FIFO
+    override def rootPool: Pool = new Pool("", schedulingMode, 0, 0)
+    override def start() = {}
+    override def stop() = {}
+    override def executorHeartbeatReceived(
+        execId: String,
+        accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])],
+        blockManagerId: BlockManagerId,
+        executorUpdates: Map[(Int, Int), ExecutorMetrics]): Boolean = true
+    override def submitTasks(taskSet: TaskSet) = {
+      // normally done by TaskSetManager
+      taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
+      taskSets += taskSet
+    }
+    override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = {
+      cancelledStages += stageId
+    }
+    override def killTaskAttempt(
+      taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
+    override def killAllTaskAttempts(
+      stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
+    override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
+      taskSets.filter(_.stageId == stageId).lastOption.foreach { ts =>
+        val tasks = ts.tasks.filter(_.partitionId == partitionId)
+        assert(tasks.length == 1)
+        tasksMarkedAsCompleted += tasks.head
+      }
+    }
+    override def setDAGScheduler(dagScheduler: DAGScheduler) = {}
+    override def defaultParallelism() = 2
+    override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
+    override def workerRemoved(workerId: String, host: String, message: String): Unit = {}
+    override def applicationAttemptId(): Option[String] = None
+    override def executorDecommission(
+      executorId: String,
+      decommissionInfo: ExecutorDecommissionInfo): Unit = {}
+    override def getExecutorDecommissionState(
+      executorId: String): Option[ExecutorDecommissionState] = None
+  }
+
+  /**
+   * Listeners which records some information to verify in UTs. Getter-kind methods in this class
+   * ensures the value is returned after ensuring there's no event to process, as well as the
+   * value is immutable: prevent showing odd result by race condition.
+   */
+  class EventInfoRecordingListener extends SparkListener {
+    private val _submittedStageInfos = new HashSet[StageInfo]
+    private val _successfulStages = new HashSet[Int]
+    private val _failedStages = new ArrayBuffer[Int]
+    private val _stageByOrderOfExecution = new ArrayBuffer[Int]
+    private val _endedTasks = new HashSet[Long]
+
+    override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
+      _submittedStageInfos += stageSubmitted.stageInfo
+    }
+
+    override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
+      val stageInfo = stageCompleted.stageInfo
+      _stageByOrderOfExecution += stageInfo.stageId
+      if (stageInfo.failureReason.isEmpty) {
+        _successfulStages += stageInfo.stageId
+      } else {
+        _failedStages += stageInfo.stageId
+      }
+    }
+
+    override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+      _endedTasks += taskEnd.taskInfo.taskId
+    }
+
+    def submittedStageInfos: Set[StageInfo] = {
+      waitForListeners()
+      _submittedStageInfos.toSet
+    }
+
+    def successfulStages: Set[Int] = {
+      waitForListeners()
+      _successfulStages.toSet
+    }
+
+    def failedStages: List[Int] = {
+      waitForListeners()
+      _failedStages.toList
+    }
+
+    def stageByOrderOfExecution: List[Int] = {
+      waitForListeners()
+      _stageByOrderOfExecution.toList
+    }
+
+    def endedTasks: Set[Long] = {
+      waitForListeners()
+      _endedTasks.toSet
+    }
+
+    private def waitForListeners(): Unit = sc.listenerBus.waitUntilEmpty()
+  }
+
+  var sparkListener: EventInfoRecordingListener = null
+
+  var blockManagerMaster: BlockManagerMaster = null
+  var mapOutputTracker: MapOutputTrackerMaster = null
+  var broadcastManager: BroadcastManager = null
+  var securityMgr: SecurityManager = null
+  var scheduler: DAGScheduler = null
+  var dagEventProcessLoopTester: DAGSchedulerEventProcessLoop = null
+
+  /**
+   * Set of cache locations to return from our mock BlockManagerMaster.
+   * Keys are (rdd ID, partition ID). Anything not present will return an empty
+   * list of cache locations silently.
+   */
+  val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
+  // stub out BlockManagerMaster.getLocations to use our cacheLocations
+  class MyBlockManagerMaster(conf: SparkConf) extends BlockManagerMaster(null, null, conf, true) {
+    override def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = {
+      blockIds.map {
+        _.asRDDId.map { id => (id.rddId -> id.splitIndex)
+        }.flatMap { key => cacheLocations.get(key)
+        }.getOrElse(Seq())
+      }.toIndexedSeq
+    }
+    override def removeExecutor(execId: String): Unit = {
+      // don't need to propagate to the driver, which we don't have
+    }
+  }
+
+  /** The list of results that DAGScheduler has collected. */
+  val results = new HashMap[Int, Any]()
+  var failure: Exception = _
+  val jobListener = new JobListener() {
+    override def taskSucceeded(index: Int, result: Any) = results.put(index, result)
+    override def jobFailed(exception: Exception) = { failure = exception }
+  }
+
+  /** A simple helper class for creating custom JobListeners */
+  class SimpleListener extends JobListener {
+    val results = new HashMap[Int, Any]
+    var failure: Exception = null
+    override def taskSucceeded(index: Int, result: Any): Unit = results.put(index, result)
+    override def jobFailed(exception: Exception): Unit = { failure = exception }
+  }
+
+  class MyMapOutputTrackerMaster(
+      conf: SparkConf,
+      broadcastManager: BroadcastManager)
+    extends MapOutputTrackerMaster(conf, broadcastManager, true) {
+
+    override def sendTracker(message: Any): Unit = {
+      // no-op, just so we can stop this to avoid leaking threads
+    }
+  }
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    firstInit = true
+  }
+
+  override def sc: SparkContext = {
+    val sc = super.sc
+    if (firstInit) {
+      init(sc)
+      firstInit = false
+    }
+    sc
+  }
+
+  private def init(sc: SparkContext): Unit = {
+    sparkListener = new EventInfoRecordingListener
+    failure = null
+    sc.addSparkListener(sparkListener)
+    taskSets.clear()
+    tasksMarkedAsCompleted.clear()
+    cancelledStages.clear()
+    cacheLocations.clear()
+    results.clear()
+    securityMgr = new SecurityManager(sc.getConf)
+    broadcastManager = new BroadcastManager(true, sc.getConf, securityMgr)
+    mapOutputTracker = spy(new MyMapOutputTrackerMaster(sc.getConf, broadcastManager))
+    blockManagerMaster = spy(new MyBlockManagerMaster(sc.getConf))
+    scheduler = new DAGScheduler(
+      sc,
+      taskScheduler,
+      sc.listenerBus,
+      mapOutputTracker,
+      blockManagerMaster,
+      sc.env)
+    dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler)
+  }
+
+  override def afterEach(): Unit = {
+    try {
+      scheduler.stop()
+      dagEventProcessLoopTester.stop()
+      mapOutputTracker.stop()
+      broadcastManager.stop()
+    } finally {
+      super.afterEach()
+    }
+  }
+
+  override def afterAll(): Unit = {
+    super.afterAll()
+  }
+
+  /**
+   * Type of RDD we use for testing. Note that we should never call the real RDD compute methods.
+   * This is a pair RDD type so it can always be used in ShuffleDependencies.
+   */
+  type PairOfIntsRDD = RDD[(Int, Int)]
+
+  /**
+   * Process the supplied event as if it were the top of the DAGScheduler event queue, expecting
+   * the scheduler not to exit.
+   *
+   * After processing the event, submit waiting stages as is done on most iterations of the
+   * DAGScheduler event loop.
+   */
+  protected def runEvent(event: DAGSchedulerEvent): Unit = {
+    // Ensure the initialization of various components
+    sc
+    dagEventProcessLoopTester.post(event)
+  }
+
+  /**
+   * When we submit dummy Jobs, this is the compute function we supply. Except in a local test
+   * below, we do not expect this function to ever be executed; instead, we will return results
+   * directly through CompletionEvents.
+   */
+  private val jobComputeFunc = (context: TaskContext, it: Iterator[(_)]) =>
+    it.next.asInstanceOf[Tuple2[_, _]]._1
+
+  /** Send the given CompletionEvent messages for the tasks in the TaskSet. */
+  protected def complete(taskSet: TaskSet, taskEndInfos: Seq[(TaskEndReason, Any)]): Unit = {
+    assert(taskSet.tasks.size >= taskEndInfos.size)
+    for ((result, i) <- taskEndInfos.zipWithIndex) {
+      if (i < taskSet.tasks.size) {
+        runEvent(makeCompletionEvent(taskSet.tasks(i), result._1, result._2))
+      }
+    }
+  }
+
+  protected def completeWithAccumulator(
+      accumId: Long,
+      taskSet: TaskSet,
+      results: Seq[(TaskEndReason, Any)]): Unit = {
+    assert(taskSet.tasks.size >= results.size)
+    for ((result, i) <- results.zipWithIndex) {
+      if (i < taskSet.tasks.size) {
+        runEvent(makeCompletionEvent(
+          taskSet.tasks(i),
+          result._1,
+          result._2,
+          Seq(AccumulatorSuite.createLongAccum("", initValue = 1, id = accumId))))
+      }
+    }
+  }
+
+  /** Submits a job to the scheduler and returns the job id. */
+  protected def submit(
+      rdd: RDD[_],
+      partitions: Array[Int],
+      func: (TaskContext, Iterator[_]) => _ = jobComputeFunc,
+      listener: JobListener = jobListener,
+      properties: Properties = null): Int = {
+    val jobId = scheduler.nextJobId.getAndIncrement()
+    runEvent(JobSubmitted(jobId, rdd, func, partitions, CallSite("", ""), listener, properties))
+    jobId
+  }
+
+  /** Submits a map stage to the scheduler and returns the job id. */
+  protected def submitMapStage(
+      shuffleDep: ShuffleDependency[_, _, _],
+      listener: JobListener = jobListener): Int = {
+    val jobId = scheduler.nextJobId.getAndIncrement()
+    runEvent(MapStageSubmitted(jobId, shuffleDep, CallSite("", ""), listener))
+    jobId
+  }
+
+  /** Sends TaskSetFailed to the scheduler. */
+  protected def failed(taskSet: TaskSet, message: String): Unit = {
+    runEvent(TaskSetFailed(taskSet, message, None))
+  }
+
+  /** Sends JobCancelled to the DAG scheduler. */
+  protected def cancel(jobId: Int): Unit = {
+    runEvent(JobCancelled(jobId, None))
+  }
+
+  /** Make some tasks in task set success and check results. */
+  protected def completeAndCheckAnswer(
+      taskSet: TaskSet,
+      taskEndInfos: Seq[(TaskEndReason, Any)],
+      expected: Map[Int, Any]): Unit = {
+    complete(taskSet, taskEndInfos)
+    assert(this.results === expected)
+  }
+
+  // Helper function to validate state when creating tests for task failures
+  private def checkStageId(stageId: Int, attempt: Int, stageAttempt: TaskSet): Unit = {
+    assert(stageAttempt.stageId === stageId)
+    assert(stageAttempt.stageAttemptId == attempt)
+  }
+
+  // Helper functions to extract commonly used code in Fetch Failure test cases
+  protected def setupStageAbortTest(sc: SparkContext): Unit = {
+    sc.listenerBus.addToSharedQueue(new EndListener())
+    ended = false
+    jobResult = null
+  }
+
+  // Create a new Listener to confirm that the listenerBus sees the JobEnd message
+  // when we abort the stage. This message will also be consumed by the EventLoggingListener
+  // so this will propagate up to the user.
+  var ended = false
+  var jobResult : JobResult = null
+
+  class EndListener extends SparkListener {
+    override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+      jobResult = jobEnd.jobResult
+      ended = true
+    }
+  }
+
+  /**
+   * Common code to get the next stage attempt, confirm it's the one we expect, and complete it
+   * successfully.
+   *
+   * @param stageId - The current stageId
+   * @param attemptIdx - The current attempt count
+   * @param numShufflePartitions - The number of partitions in the next stage
+   * @param hostNames - Host on which each task in the task set is executed
+   */
+  protected def completeShuffleMapStageSuccessfully(

Review comment:
       This is also used only by `DAGSchedulerSuite`. Shall we keep it in `DAGSchedulerSuite`? Could you check the others as well?
   
   




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org