Posted to commits@spark.apache.org by we...@apache.org on 2018/09/11 04:06:31 UTC

spark git commit: [SPARK-23243][SPARK-20715][CORE][2.2] Fix RDD.repartition() data correctness issue

Repository: spark
Updated Branches:
  refs/heads/branch-2.2 af41dedc6 -> 3158fc3ce


[SPARK-23243][SPARK-20715][CORE][2.2] Fix RDD.repartition() data correctness issue

## What changes were proposed in this pull request?

Backport of #22354 and #17955 to 2.2 (#22354 depends on methods introduced by #17955).

-------

An alternative fix for #21698

When Spark reruns tasks for an RDD, there are three possible behaviors:
1. determinate. Always returns the same result, in the same order, when rerun.
2. unordered. Returns the same data set, possibly in a different order, when rerun.
3. indeterminate. Returns a different result when rerun.
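The three behaviors can be modeled as a small Scala ADT. This is a hypothetical sketch for illustration only. `RerunBehavior`, `Determinate`, `Unordered`, `Indeterminate` and `classify` are names introduced here, not Spark APIs. `classify` compares the output of one run against a rerun and returns the strongest behavior consistent with those two observations. Two runs can only show that a function is *not* determinate; they cannot prove that it is.

```scala
// Hypothetical model of the three rerun behaviors (not Spark's API).
sealed trait RerunBehavior
case object Determinate   extends RerunBehavior // same records, same order
case object Unordered     extends RerunBehavior // same records, order may differ
case object Indeterminate extends RerunBehavior // the records themselves may differ

object RerunBehavior {
  // Multiset view of a run: how many times each record appears.
  private def counts[T](xs: Seq[T]): Map[T, Int] =
    xs.groupBy(identity).map { case (k, v) => k -> v.size }

  /** Strongest behavior consistent with the two observed runs. */
  def classify[T](firstRun: Seq[T], rerun: Seq[T]): RerunBehavior =
    if (firstRun == rerun) Determinate
    else if (counts(firstRun) == counts(rerun)) Unordered
    else Indeterminate
}
```

For example, `classify(Seq(1, 2, 3), Seq(3, 1, 2))` is `Unordered`, and `classify(Seq(1, 1, 2), Seq(1, 2, 2))` is `Indeterminate` because the record counts differ.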

Normally Spark doesn't need to care about this. Spark runs stages one by one, and when a task fails it simply reruns that task. Although the rerun task may return a different result, users will not be surprised.

However, Spark may rerun a finished stage when it sees fetch failures. When this happens, Spark needs to rerun all the tasks of all the succeeding stages if the RDD output is indeterminate, because the input of the succeeding stages has changed.

If the RDD output is determinate, we only need to rerun the failed tasks of the succeeding stages, because the input doesn't change.

If the RDD output is unordered, it's the same as determinate. The shuffle partitioner is always deterministic (the round-robin partitioner is not a shuffle partitioner that extends `org.apache.spark.Partitioner`), so the reducers still get the same input data set.
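The rule in the paragraphs above can be written as a function of the rerun parent stage's output level. This is a minimal sketch that assumes hypothetical names (`RerunPolicy`, `partitionsToRerun`) that do not exist in Spark. It does not reproduce the scheduler changes in this commit (see `DAGScheduler.scala` in the file list). It only restates the decision described in the text.

```scala
// Hypothetical restatement of the rerun rule; not Spark's DAGScheduler code.
object RerunPolicy {
  sealed trait Level
  case object Determinate   extends Level
  case object Unordered     extends Level
  case object Indeterminate extends Level

  /**
   * Partitions of a succeeding stage that must be recomputed after its parent stage reran.
   *
   * @param parentLevel   how the parent stage's output behaves on recomputation
   * @param failedTasks   partitions of the succeeding stage that failed (e.g. on fetch failure)
   * @param numPartitions total number of partitions in the succeeding stage
   */
  def partitionsToRerun(parentLevel: Level, failedTasks: Set[Int], numPartitions: Int): Set[Int] =
    parentLevel match {
      // A deterministic partitioner sends the same records to the same reducer, so
      // finished tasks already saw their final input; only failed tasks are rerun.
      case Determinate | Unordered => failedTasks
      // The parent may now emit different records, so every downstream task's input
      // may have changed and all of them must be rerun.
      case Indeterminate => (0 until numPartitions).toSet
    }
}
```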

This PR fixes the failure handling for `repartition` to avoid correctness issues.

`repartition` applies a stateful map function to generate a round-robin id. That function is order-sensitive, which makes the RDD's output indeterminate. When a stage containing `repartition` is rerun, we must also rerun all the tasks of all the succeeding stages.
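To see why a round-robin id makes the output indeterminate, compare it with value-based hash partitioning on two runs that produce the same records in a different order (an "unordered" parent). This is a simplified, hypothetical model, and `RoundRobinDemo` is not Spark code. Spark starts each input partition's counter at a pseudo-random position derived from the partition index, but this sketch starts at 0. Records are assumed distinct so they can serve as map keys.

```scala
// Simplified, hypothetical model of round-robin keying versus hash partitioning.
object RoundRobinDemo {
  /** Target partition chosen from the record's position in the input (order-sensitive). */
  def roundRobin[T](records: Seq[T], numPartitions: Int): Map[T, Int] =
    records.zipWithIndex.map { case (r, i) => r -> (i % numPartitions) }.toMap

  /** Target partition chosen from the record's value only (order-insensitive). */
  def byHash[T](records: Seq[T], numPartitions: Int): Map[T, Int] =
    records.map(r => r -> Math.floorMod(r.hashCode, numPartitions)).toMap
}

val firstRun = Seq("a", "b", "c", "d")
val rerun    = Seq("b", "a", "c", "d") // same records, different order

RoundRobinDemo.roundRobin(firstRun, 2) // a -> 0, b -> 1, c -> 0, d -> 1
RoundRobinDemo.roundRobin(rerun, 2)    // b -> 0, a -> 1, c -> 0, d -> 1
```

Here `a` lands in partition 0 on the first run but in partition 1 on the rerun. A reducer that already finished therefore read different input than a rerun reducer would. `byHash` gives the same assignment for both runs, which is why an unordered parent is harmless behind a real `Partitioner`.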

**future improvement:**
1. Currently we can't roll back and rerun a shuffle map stage; we just fail. We should fix this later. https://issues.apache.org/jira/browse/SPARK-25341
2. Currently we can't roll back and rerun a result stage; we just fail. We should fix this later. https://issues.apache.org/jira/browse/SPARK-25342
3. We should provide a public API that allows users to tag the randomness level of an RDD's computing function.

## How was this patch tested?

A new test case.

Closes #22382 from bersprockets/SPARK-23243-2.2.

Lead-authored-by: Bruce Robbins <be...@gmail.com>
Co-authored-by: Josh Rosen <jo...@databricks.com>
Co-authored-by: Wenchen Fan <we...@databricks.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3158fc3c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3158fc3c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3158fc3c

Branch: refs/heads/branch-2.2
Commit: 3158fc3ce390f96d8c65d70bcdf9ac9aa26be24b
Parents: af41ded
Author: Bruce Robbins <be...@gmail.com>
Authored: Tue Sep 11 12:06:19 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Sep 11 12:06:19 2018 +0800

----------------------------------------------------------------------
 .../org/apache/spark/MapOutputTracker.scala     | 636 +++++++++++--------
 .../scala/org/apache/spark/Partitioner.scala    |   3 +
 .../org/apache/spark/executor/Executor.scala    |  10 +-
 .../org/apache/spark/rdd/MapPartitionsRDD.scala |  21 +-
 .../main/scala/org/apache/spark/rdd/RDD.scala   | 100 ++-
 .../apache/spark/scheduler/DAGScheduler.scala   | 110 ++--
 .../spark/scheduler/ShuffleMapStage.scala       |  76 +--
 .../spark/scheduler/TaskSchedulerImpl.scala     |   2 +-
 .../apache/spark/MapOutputTrackerSuite.scala    |   6 +-
 .../scala/org/apache/spark/ShuffleSuite.scala   |   3 +-
 .../spark/scheduler/BlacklistTrackerSuite.scala |   3 +-
 .../spark/scheduler/DAGSchedulerSuite.scala     | 169 ++++-
 .../execution/exchange/ShuffleExchange.scala    |  17 +-
 13 files changed, 750 insertions(+), 406 deletions(-)
----------------------------------------------------------------------
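The `MapOutputTracker.scala` diff below moves per-shuffle bookkeeping into a new `ShuffleStatus` class. That class holds an array of statuses indexed by map id, a count of available outputs, and a lazily built serialized form that is invalidated when outputs change. The following is a simplified, hypothetical model of that pattern, not Spark code. `Location` stands in for `MapStatus`, and "serialization" is a plain `mkString`. The diff's `addMapOutput` invalidates the cache only when a previously empty slot is filled; this sketch instead invalidates on every registration.

```scala
// Hypothetical, simplified model of ShuffleStatus-style bookkeeping (not Spark code).
final case class Location(executorId: String, host: String)

class SimpleShuffleStatus(numPartitions: Int) {
  // All mutable state is guarded by `this.synchronized`, as in the diff.
  private val statuses = new Array[Location](numPartitions)
  private var numAvailable = 0
  private var cachedSerialized: String = null // null means "not cached"

  def addMapOutput(mapId: Int, loc: Location): Unit = synchronized {
    if (statuses(mapId) == null) numAvailable += 1
    statuses(mapId) = loc
    cachedSerialized = null // any registration invalidates the cached form
  }

  def removeOutputsOnExecutor(execId: String): Unit = synchronized {
    for (i <- statuses.indices if statuses(i) != null && statuses(i).executorId == execId) {
      statuses(i) = null
      numAvailable -= 1
      cachedSerialized = null
    }
  }

  def numAvailableOutputs: Int = synchronized { numAvailable }

  /** Map ids with no registered output, i.e. tasks that still need to run. */
  def findMissingPartitions(): Seq[Int] = synchronized {
    (0 until numPartitions).filter(i => statuses(i) == null)
  }

  /** Builds the serialized form once and reuses it until the next mutation. */
  def serialized(): String = synchronized {
    if (cachedSerialized == null) {
      cachedSerialized = statuses.map(s => if (s == null) "-" else s.executorId).mkString(",")
    }
    cachedSerialized
  }
}
```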


http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 4ef6656..3e10b9e 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -34,6 +34,156 @@ import org.apache.spark.shuffle.MetadataFetchFailedException
 import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
 import org.apache.spark.util._
 
+/**
+ * Helper class used by the [[MapOutputTrackerMaster]] to perform bookkeeping for a single
+ * ShuffleMapStage.
+ *
+ * This class maintains a mapping from mapIds to `MapStatus`. It also maintains a cache of
+ * serialized map statuses in order to speed up tasks' requests for map output statuses.
+ *
+ * All public methods of this class are thread-safe.
+ */
+private class ShuffleStatus(numPartitions: Int) {
+
+  // All accesses to the following state must be guarded with `this.synchronized`.
+
+  /**
+   * MapStatus for each partition. The index of the array is the map partition id.
+   * Each value in the array is the MapStatus for a partition, or null if the partition
+   * is not available. Even though in theory a task may run multiple times (due to speculation,
+   * stage retries, etc.), in practice the likelihood of a map output being available at multiple
+   * locations is so small that we choose to ignore that case and store only a single location
+   * for each output.
+   */
+  private[this] val mapStatuses = new Array[MapStatus](numPartitions)
+
+  /**
+   * The cached result of serializing the map statuses array. This cache is lazily populated when
+   * [[serializedMapStatus]] is called. The cache is invalidated when map outputs are removed.
+   */
+  private[this] var cachedSerializedMapStatus: Array[Byte] = _
+
+  /**
+   * Broadcast variable holding serialized map output statuses array. When [[serializedMapStatus]]
+   * serializes the map statuses array it may detect that the result is too large to send in a
+   * single RPC, in which case it places the serialized array into a broadcast variable and then
+   * sends a serialized broadcast variable instead. This variable holds a reference to that
+   * broadcast variable in order to keep it from being garbage collected and to allow for it to be
+   * explicitly destroyed later on when the ShuffleMapStage is garbage-collected.
+   */
+  private[this] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _
+
+  /**
+   * Counter tracking the number of partitions that have output. This is a performance optimization
+   * to avoid having to count the number of non-null entries in the `mapStatuses` array and should
+   * be equivalent to`mapStatuses.count(_ ne null)`.
+   */
+  private[this] var _numAvailableOutputs: Int = 0
+
+  /**
+   * Register a map output. If there is already a registered location for the map output then it
+   * will be replaced by the new location.
+   */
+  def addMapOutput(mapId: Int, status: MapStatus): Unit = synchronized {
+    if (mapStatuses(mapId) == null) {
+      _numAvailableOutputs += 1
+      invalidateSerializedMapOutputStatusCache()
+    }
+    mapStatuses(mapId) = status
+  }
+
+  /**
+   * Remove the map output which was served by the specified block manager.
+   * This is a no-op if there is no registered map output or if the registered output is from a
+   * different block manager.
+   */
+  def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized {
+    if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) {
+      _numAvailableOutputs -= 1
+      mapStatuses(mapId) = null
+      invalidateSerializedMapOutputStatusCache()
+    }
+  }
+
+  /**
+   * Removes all map outputs associated with the specified executor. Note that this will also
+   * remove outputs which are served by an external shuffle server (if one exists), as they are
+   * still registered with that execId.
+   */
+  def removeOutputsOnExecutor(execId: String): Unit = synchronized {
+    for (mapId <- 0 until mapStatuses.length) {
+      if (mapStatuses(mapId) != null && mapStatuses(mapId).location.executorId == execId) {
+        _numAvailableOutputs -= 1
+        mapStatuses(mapId) = null
+        invalidateSerializedMapOutputStatusCache()
+      }
+    }
+  }
+
+  /**
+   * Number of partitions that have shuffle outputs.
+   */
+  def numAvailableOutputs: Int = synchronized {
+    _numAvailableOutputs
+  }
+
+  /**
+   * Returns the sequence of partition ids that are missing (i.e. needs to be computed).
+   */
+  def findMissingPartitions(): Seq[Int] = synchronized {
+    val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null)
+    assert(missing.size == numPartitions - _numAvailableOutputs,
+      s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}")
+    missing
+  }
+
+  /**
+   * Serializes the mapStatuses array into an efficient compressed format. See the comments on
+   * `MapOutputTracker.serializeMapStatuses()` for more details on the serialization format.
+   *
+   * This method is designed to be called multiple times and implements caching in order to speed
+   * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to
+   * serialize the map statuses then serialization will only be performed in a single thread and all
+   * other threads will block until the cache is populated.
+   */
+  def serializedMapStatus(
+      broadcastManager: BroadcastManager,
+      isLocal: Boolean,
+      minBroadcastSize: Int): Array[Byte] = synchronized {
+    if (cachedSerializedMapStatus eq null) {
+      val serResult = MapOutputTracker.serializeMapStatuses(
+          mapStatuses, broadcastManager, isLocal, minBroadcastSize)
+      cachedSerializedMapStatus = serResult._1
+      cachedSerializedBroadcast = serResult._2
+    }
+    cachedSerializedMapStatus
+  }
+
+  // Used in testing.
+  def hasCachedSerializedBroadcast: Boolean = synchronized {
+    cachedSerializedBroadcast != null
+  }
+
+  /**
+   * Helper function which provides thread-safe access to the mapStatuses array.
+   * The function should NOT mutate the array.
+   */
+  def withMapStatuses[T](f: Array[MapStatus] => T): T = synchronized {
+    f(mapStatuses)
+  }
+
+  /**
+   * Clears the cached serialized map output statuses.
+   */
+  def invalidateSerializedMapOutputStatusCache(): Unit = synchronized {
+    if (cachedSerializedBroadcast != null) {
+      cachedSerializedBroadcast.destroy()
+      cachedSerializedBroadcast = null
+    }
+    cachedSerializedMapStatus = null
+  }
+}
+
 private[spark] sealed trait MapOutputTrackerMessage
 private[spark] case class GetMapOutputStatuses(shuffleId: Int)
   extends MapOutputTrackerMessage
@@ -62,37 +212,26 @@ private[spark] class MapOutputTrackerMasterEndpoint(
 }
 
 /**
- * Class that keeps track of the location of the map output of
- * a stage. This is abstract because different versions of MapOutputTracker
- * (driver and executor) use different HashMap to store its metadata.
- */
+ * Class that keeps track of the location of the map output of a stage. This is abstract because the
+ * driver and executor have different versions of the MapOutputTracker. In principle the driver-
+ * and executor-side classes don't need to share a common base class; the current shared base class
+ * is maintained primarily for backwards-compatibility in order to avoid having to update existing
+ * test code.
+*/
 private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging {
-
   /** Set to the MapOutputTrackerMasterEndpoint living on the driver. */
   var trackerEndpoint: RpcEndpointRef = _
 
   /**
-   * This HashMap has different behavior for the driver and the executors.
-   *
-   * On the driver, it serves as the source of map outputs recorded from ShuffleMapTasks.
-   * On the executors, it simply serves as a cache, in which a miss triggers a fetch from the
-   * driver's corresponding HashMap.
-   *
-   * Note: because mapStatuses is accessed concurrently, subclasses should make sure it's a
-   * thread-safe map.
-   */
-  protected val mapStatuses: Map[Int, Array[MapStatus]]
-
-  /**
-   * Incremented every time a fetch fails so that client nodes know to clear
-   * their cache of map output locations if this happens.
+   * The driver-side counter is incremented every time that a map output is lost. This value is sent
+   * to executors as part of tasks, where executors compare the new epoch number to the highest
+   * epoch number that they received in the past. If the new epoch number is higher then executors
+   * will clear their local caches of map output statuses and will re-fetch (possibly updated)
+   * statuses from the driver.
    */
   protected var epoch: Long = 0
   protected val epochLock = new AnyRef
 
-  /** Remembers which map output locations are currently being fetched on an executor. */
-  private val fetching = new HashSet[Int]
-
   /**
    * Send a message to the trackerEndpoint and get its result within a default timeout, or
    * throw a SparkException if this fails.
@@ -116,14 +255,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
     }
   }
 
-  /**
-   * Called from executors to get the server URIs and output sizes for each shuffle block that
-   * needs to be read from a given reduce task.
-   *
-   * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
-   *         and the second item is a sequence of (shuffle block id, shuffle block size) tuples
-   *         describing the shuffle blocks that are stored at that block manager.
-   */
+  // For testing
   def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
       : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
     getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)
@@ -139,135 +271,31 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
    *         describing the shuffle blocks that are stored at that block manager.
    */
   def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
-      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
-    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
-    val statuses = getStatuses(shuffleId)
-    // Synchronize on the returned array because, on the driver, it gets mutated in place
-    statuses.synchronized {
-      return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
-    }
-  }
+      : Seq[(BlockManagerId, Seq[(BlockId, Long)])]
 
   /**
-   * Return statistics about all of the outputs for a given shuffle.
+   * Deletes map output status information for the specified shuffle stage.
    */
-  def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = {
-    val statuses = getStatuses(dep.shuffleId)
-    // Synchronize on the returned array because, on the driver, it gets mutated in place
-    statuses.synchronized {
-      val totalSizes = new Array[Long](dep.partitioner.numPartitions)
-      for (s <- statuses) {
-        for (i <- 0 until totalSizes.length) {
-          totalSizes(i) += s.getSizeForBlock(i)
-        }
-      }
-      new MapOutputStatistics(dep.shuffleId, totalSizes)
-    }
-  }
-
-  /**
-   * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize
-   * on this array when reading it, because on the driver, we may be changing it in place.
-   *
-   * (It would be nice to remove this restriction in the future.)
-   */
-  private def getStatuses(shuffleId: Int): Array[MapStatus] = {
-    val statuses = mapStatuses.get(shuffleId).orNull
-    if (statuses == null) {
-      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
-      val startTime = System.currentTimeMillis
-      var fetchedStatuses: Array[MapStatus] = null
-      fetching.synchronized {
-        // Someone else is fetching it; wait for them to be done
-        while (fetching.contains(shuffleId)) {
-          try {
-            fetching.wait()
-          } catch {
-            case e: InterruptedException =>
-          }
-        }
-
-        // Either while we waited the fetch happened successfully, or
-        // someone fetched it in between the get and the fetching.synchronized.
-        fetchedStatuses = mapStatuses.get(shuffleId).orNull
-        if (fetchedStatuses == null) {
-          // We have to do the fetch, get others to wait for us.
-          fetching += shuffleId
-        }
-      }
+  def unregisterShuffle(shuffleId: Int): Unit
 
-      if (fetchedStatuses == null) {
-        // We won the race to fetch the statuses; do so
-        logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
-        // This try-finally prevents hangs due to timeouts:
-        try {
-          val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
-          fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
-          logInfo("Got the output locations")
-          mapStatuses.put(shuffleId, fetchedStatuses)
-        } finally {
-          fetching.synchronized {
-            fetching -= shuffleId
-            fetching.notifyAll()
-          }
-        }
-      }
-      logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
-        s"${System.currentTimeMillis - startTime} ms")
-
-      if (fetchedStatuses != null) {
-        return fetchedStatuses
-      } else {
-        logError("Missing all output locations for shuffle " + shuffleId)
-        throw new MetadataFetchFailedException(
-          shuffleId, -1, "Missing all output locations for shuffle " + shuffleId)
-      }
-    } else {
-      return statuses
-    }
-  }
-
-  /** Called to get current epoch number. */
-  def getEpoch: Long = {
-    epochLock.synchronized {
-      return epoch
-    }
-  }
-
-  /**
-   * Called from executors to update the epoch number, potentially clearing old outputs
-   * because of a fetch failure. Each executor task calls this with the latest epoch
-   * number on the driver at the time it was created.
-   */
-  def updateEpoch(newEpoch: Long) {
-    epochLock.synchronized {
-      if (newEpoch > epoch) {
-        logInfo("Updating epoch to " + newEpoch + " and clearing cache")
-        epoch = newEpoch
-        mapStatuses.clear()
-      }
-    }
-  }
-
-  /** Unregister shuffle data. */
-  def unregisterShuffle(shuffleId: Int) {
-    mapStatuses.remove(shuffleId)
-  }
-
-  /** Stop the tracker. */
-  def stop() { }
+  def stop() {}
 }
 
 /**
- * MapOutputTracker for the driver.
+ * Driver-side class that keeps track of the location of the map output of a stage.
+ *
+ * The DAGScheduler uses this class to (de)register map output statuses and to look up statistics
+ * for performing locality-aware reduce task scheduling.
+ *
+ * ShuffleMapStage uses this class for tracking available / missing outputs in order to determine
+ * which tasks need to be run.
  */
-private[spark] class MapOutputTrackerMaster(conf: SparkConf,
-    broadcastManager: BroadcastManager, isLocal: Boolean)
+private[spark] class MapOutputTrackerMaster(
+    conf: SparkConf,
+    broadcastManager: BroadcastManager,
+    isLocal: Boolean)
   extends MapOutputTracker(conf) {
 
-  /** Cache a serialized version of the output statuses for each shuffle to send them out faster */
-  private var cacheEpoch = epoch
-
   // The size at which we use Broadcast to send the map output statuses to the executors
   private val minSizeForBroadcast =
     conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", "512k").toInt
@@ -287,22 +315,12 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf,
   // can be read locally, but may lead to more delay in scheduling if those locations are busy.
   private val REDUCER_PREF_LOCS_FRACTION = 0.2
 
-  // HashMaps for storing mapStatuses and cached serialized statuses in the driver.
+  // HashMap for storing shuffleStatuses in the driver.
   // Statuses are dropped only by explicit de-registering.
-  protected val mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala
-  private val cachedSerializedStatuses = new ConcurrentHashMap[Int, Array[Byte]]().asScala
+  private val shuffleStatuses = new ConcurrentHashMap[Int, ShuffleStatus]().asScala
 
   private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
 
-  // Kept in sync with cachedSerializedStatuses explicitly
-  // This is required so that the Broadcast variable remains in scope until we remove
-  // the shuffleId explicitly or implicitly.
-  private val cachedSerializedBroadcast = new HashMap[Int, Broadcast[Array[Byte]]]()
-
-  // This is to prevent multiple serializations of the same shuffle - which happens when
-  // there is a request storm when shuffle start.
-  private val shuffleIdLocks = new ConcurrentHashMap[Int, AnyRef]()
-
   // requests for map output statuses
   private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage]
 
@@ -348,8 +366,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf,
             val hostPort = context.senderAddress.hostPort
             logDebug("Handling request to send map output locations for shuffle " + shuffleId +
               " to " + hostPort)
-            val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId)
-            context.reply(mapOutputStatuses)
+            val shuffleStatus = shuffleStatuses.get(shuffleId).head
+            context.reply(
+              shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast))
           } catch {
             case NonFatal(e) => logError(e.getMessage, e)
           }
@@ -363,59 +382,77 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf,
   /** A poison endpoint that indicates MessageLoop should exit its message loop. */
   private val PoisonPill = new GetMapOutputMessage(-99, null)
 
-  // Exposed for testing
-  private[spark] def getNumCachedSerializedBroadcast = cachedSerializedBroadcast.size
+  // Used only in unit tests.
+  private[spark] def getNumCachedSerializedBroadcast: Int = {
+    shuffleStatuses.valuesIterator.count(_.hasCachedSerializedBroadcast)
+  }
 
   def registerShuffle(shuffleId: Int, numMaps: Int) {
-    if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
+    if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) {
       throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
     }
-    // add in advance
-    shuffleIdLocks.putIfAbsent(shuffleId, new Object())
   }
 
   def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
-    val array = mapStatuses(shuffleId)
-    array.synchronized {
-      array(mapId) = status
-    }
-  }
-
-  /** Register multiple map output information for the given shuffle */
-  def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) {
-    mapStatuses.put(shuffleId, statuses.clone())
-    if (changeEpoch) {
-      incrementEpoch()
-    }
+    shuffleStatuses(shuffleId).addMapOutput(mapId, status)
   }
 
   /** Unregister map output information of the given shuffle, mapper and block manager */
   def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
-    val arrayOpt = mapStatuses.get(shuffleId)
-    if (arrayOpt.isDefined && arrayOpt.get != null) {
-      val array = arrayOpt.get
-      array.synchronized {
-        if (array(mapId) != null && array(mapId).location == bmAddress) {
-          array(mapId) = null
-        }
-      }
-      incrementEpoch()
-    } else {
-      throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
+    shuffleStatuses.get(shuffleId) match {
+      case Some(shuffleStatus) =>
+        shuffleStatus.removeMapOutput(mapId, bmAddress)
+        incrementEpoch()
+      case None =>
+        throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
     }
   }
 
   /** Unregister shuffle data */
-  override def unregisterShuffle(shuffleId: Int) {
-    mapStatuses.remove(shuffleId)
-    cachedSerializedStatuses.remove(shuffleId)
-    cachedSerializedBroadcast.remove(shuffleId).foreach(v => removeBroadcast(v))
-    shuffleIdLocks.remove(shuffleId)
+  def unregisterShuffle(shuffleId: Int) {
+    shuffleStatuses.remove(shuffleId).foreach { shuffleStatus =>
+      shuffleStatus.invalidateSerializedMapOutputStatusCache()
+    }
+  }
+
+  /**
+   * Removes all shuffle outputs associated with this executor. Note that this will also remove
+   * outputs which are served by an external shuffle server (if one exists), as they are still
+   * registered with this execId.
+   */
+  def removeOutputsOnExecutor(execId: String): Unit = {
+    shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnExecutor(execId) }
+    incrementEpoch()
   }
 
   /** Check if the given shuffle is being tracked */
-  def containsShuffle(shuffleId: Int): Boolean = {
-    cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId)
+  def containsShuffle(shuffleId: Int): Boolean = shuffleStatuses.contains(shuffleId)
+
+  def getNumAvailableOutputs(shuffleId: Int): Int = {
+    shuffleStatuses.get(shuffleId).map(_.numAvailableOutputs).getOrElse(0)
+  }
+
+  /**
+   * Returns the sequence of partition ids that are missing (i.e. needs to be computed), or None
+   * if the MapOutputTrackerMaster doesn't know about this shuffle.
+   */
+  def findMissingPartitions(shuffleId: Int): Option[Seq[Int]] = {
+    shuffleStatuses.get(shuffleId).map(_.findMissingPartitions())
+  }
+
+  /**
+   * Return statistics about all of the outputs for a given shuffle.
+   */
+  def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = {
+    shuffleStatuses(dep.shuffleId).withMapStatuses { statuses =>
+      val totalSizes = new Array[Long](dep.partitioner.numPartitions)
+      for (s <- statuses) {
+        for (i <- 0 until totalSizes.length) {
+          totalSizes(i) += s.getSizeForBlock(i)
+        }
+      }
+      new MapOutputStatistics(dep.shuffleId, totalSizes)
+    }
   }
 
   /**
@@ -459,9 +496,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf,
       fractionThreshold: Double)
     : Option[Array[BlockManagerId]] = {
 
-    val statuses = mapStatuses.get(shuffleId).orNull
-    if (statuses != null) {
-      statuses.synchronized {
+    val shuffleStatus = shuffleStatuses.get(shuffleId).orNull
+    if (shuffleStatus != null) {
+      shuffleStatus.withMapStatuses { statuses =>
         if (statuses.nonEmpty) {
           // HashMap to add up sizes of all blocks at the same location
           val locs = new HashMap[BlockManagerId, Long]
@@ -502,77 +539,24 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf,
     }
   }
 
-  private def removeBroadcast(bcast: Broadcast[_]): Unit = {
-    if (null != bcast) {
-      broadcastManager.unbroadcast(bcast.id,
-        removeFromDriver = true, blocking = false)
+  /** Called to get current epoch number. */
+  def getEpoch: Long = {
+    epochLock.synchronized {
+      return epoch
     }
   }
 
-  private def clearCachedBroadcast(): Unit = {
-    for (cached <- cachedSerializedBroadcast) removeBroadcast(cached._2)
-    cachedSerializedBroadcast.clear()
-  }
-
-  def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
-    var statuses: Array[MapStatus] = null
-    var retBytes: Array[Byte] = null
-    var epochGotten: Long = -1
-
-    // Check to see if we have a cached version, returns true if it does
-    // and has side effect of setting retBytes.  If not returns false
-    // with side effect of setting statuses
-    def checkCachedStatuses(): Boolean = {
-      epochLock.synchronized {
-        if (epoch > cacheEpoch) {
-          cachedSerializedStatuses.clear()
-          clearCachedBroadcast()
-          cacheEpoch = epoch
-        }
-        cachedSerializedStatuses.get(shuffleId) match {
-          case Some(bytes) =>
-            retBytes = bytes
-            true
-          case None =>
-            logDebug("cached status not found for : " + shuffleId)
-            statuses = mapStatuses.getOrElse(shuffleId, Array.empty[MapStatus])
-            epochGotten = epoch
-            false
-        }
-      }
-    }
-
-    if (checkCachedStatuses()) return retBytes
-    var shuffleIdLock = shuffleIdLocks.get(shuffleId)
-    if (null == shuffleIdLock) {
-      val newLock = new Object()
-      // in general, this condition should be false - but good to be paranoid
-      val prevLock = shuffleIdLocks.putIfAbsent(shuffleId, newLock)
-      shuffleIdLock = if (null != prevLock) prevLock else newLock
-    }
-    // synchronize so we only serialize/broadcast it once since multiple threads call
-    // in parallel
-    shuffleIdLock.synchronized {
-      // double check to make sure someone else didn't serialize and cache the same
-      // mapstatus while we were waiting on the synchronize
-      if (checkCachedStatuses()) return retBytes
-
-      // If we got here, we failed to find the serialized locations in the cache, so we pulled
-      // out a snapshot of the locations as "statuses"; let's serialize and return that
-      val (bytes, bcast) = MapOutputTracker.serializeMapStatuses(statuses, broadcastManager,
-        isLocal, minSizeForBroadcast)
-      logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
-      // Add them into the table only if the epoch hasn't changed while we were working
-      epochLock.synchronized {
-        if (epoch == epochGotten) {
-          cachedSerializedStatuses(shuffleId) = bytes
-          if (null != bcast) cachedSerializedBroadcast(shuffleId) = bcast
-        } else {
-          logInfo("Epoch changed, not caching!")
-          removeBroadcast(bcast)
+  // This method is only called in local-mode.
+  def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
+      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
+    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
+    shuffleStatuses.get(shuffleId) match {
+      case Some (shuffleStatus) =>
+        shuffleStatus.withMapStatuses { statuses =>
+          MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
         }
-      }
-      bytes
+      case None =>
+        Seq.empty
     }
   }
 
@@ -580,21 +564,121 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf,
     mapOutputRequests.offer(PoisonPill)
     threadpool.shutdown()
     sendTracker(StopMapOutputTracker)
-    mapStatuses.clear()
     trackerEndpoint = null
-    cachedSerializedStatuses.clear()
-    clearCachedBroadcast()
-    shuffleIdLocks.clear()
+    shuffleStatuses.clear()
   }
 }
 
 /**
- * MapOutputTracker for the executors, which fetches map output information from the driver's
- * MapOutputTrackerMaster.
+ * Executor-side client for fetching map output info from the driver's MapOutputTrackerMaster.
+ * Note that this is not used in local-mode; instead, local-mode Executors access the
+ * MapOutputTrackerMaster directly (which is possible because the master and worker share a comon
+ * superclass).
  */
 private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {
-  protected val mapStatuses: Map[Int, Array[MapStatus]] =
+
+  val mapStatuses: Map[Int, Array[MapStatus]] =
     new ConcurrentHashMap[Int, Array[MapStatus]]().asScala
+
+  /** Remembers which map output locations are currently being fetched on an executor. */
+  private val fetching = new HashSet[Int]
+
+  override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
+      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
+    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
+    val statuses = getStatuses(shuffleId)
+    try {
+      MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
+    } catch {
+      case e: MetadataFetchFailedException =>
+        // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
+        mapStatuses.clear()
+        throw e
+    }
+  }
+
+  /**
+   * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize
+   * on this array when reading it, because on the driver, we may be changing it in place.
+   *
+   * (It would be nice to remove this restriction in the future.)
+   */
+  private def getStatuses(shuffleId: Int): Array[MapStatus] = {
+    val statuses = mapStatuses.get(shuffleId).orNull
+    if (statuses == null) {
+      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
+      val startTime = System.currentTimeMillis
+      var fetchedStatuses: Array[MapStatus] = null
+      fetching.synchronized {
+        // Someone else is fetching it; wait for them to be done
+        while (fetching.contains(shuffleId)) {
+          try {
+            fetching.wait()
+          } catch {
+            case e: InterruptedException =>
+          }
+        }
+
+        // Either while we waited the fetch happened successfully, or
+        // someone fetched it in between the get and the fetching.synchronized.
+        fetchedStatuses = mapStatuses.get(shuffleId).orNull
+        if (fetchedStatuses == null) {
+          // We have to do the fetch, get others to wait for us.
+          fetching += shuffleId
+        }
+      }
+
+      if (fetchedStatuses == null) {
+        // We won the race to fetch the statuses; do so
+        logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
+        // This try-finally prevents hangs due to timeouts:
+        try {
+          val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
+          fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
+          logInfo("Got the output locations")
+          mapStatuses.put(shuffleId, fetchedStatuses)
+        } finally {
+          fetching.synchronized {
+            fetching -= shuffleId
+            fetching.notifyAll()
+          }
+        }
+      }
+      logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
+        s"${System.currentTimeMillis - startTime} ms")
+
+      if (fetchedStatuses != null) {
+        fetchedStatuses
+      } else {
+        logError("Missing all output locations for shuffle " + shuffleId)
+        throw new MetadataFetchFailedException(
+          shuffleId, -1, "Missing all output locations for shuffle " + shuffleId)
+      }
+    } else {
+      statuses
+    }
+  }
+
+
+  /** Unregister shuffle data. */
+  def unregisterShuffle(shuffleId: Int): Unit = {
+    mapStatuses.remove(shuffleId)
+  }
+
+  /**
+   * Called from executors to update the epoch number, potentially clearing old outputs
+   * because of a fetch failure. Each executor task calls this with the latest epoch
+   * number on the driver at the time it was created.
+   */
+  def updateEpoch(newEpoch: Long): Unit = {
+    epochLock.synchronized {
+      if (newEpoch > epoch) {
+        logInfo("Updating epoch to " + newEpoch + " and clearing cache")
+        epoch = newEpoch
+        mapStatuses.clear()
+      }
+    }
+  }
 }
 
 private[spark] object MapOutputTracker extends Logging {
@@ -683,7 +767,7 @@ private[spark] object MapOutputTracker extends Logging {
    *         and the second item is a sequence of (shuffle block ID, shuffle block size) tuples
    *         describing the shuffle blocks that are stored at that block manager.
    */
-  private def convertMapStatuses(
+  def convertMapStatuses(
       shuffleId: Int,
       startPartition: Int,
       endPartition: Int,

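In `MapOutputTrackerWorker.getStatuses` above, only one thread fetches the statuses for a given shuffle ID. Other threads block on the shared `fetching` set and re-check the cache when woken. A minimal sketch of that coordination pattern, assuming a hypothetical `SingleFetcherCache` class with a caller-supplied `fetchRemote` function standing in for the `askTracker` RPC (these names are illustrative, not Spark APIs):

```scala
import java.util.concurrent.ConcurrentHashMap
import scala.collection.mutable

// Hypothetical cache: exactly one thread fetches a missing ID while
// concurrent callers for the same ID wait for that fetch to finish.
class SingleFetcherCache(fetchRemote: Int => String) {
  private val cache = new ConcurrentHashMap[Int, String]()
  // IDs currently being fetched by some thread.
  private val fetching = mutable.HashSet.empty[Int]

  def get(id: Int): String = {
    val cached = cache.get(id)
    if (cached != null) return cached

    var result: String = null
    fetching.synchronized {
      // Someone else is fetching this ID: wait until they are done.
      while (fetching.contains(id)) {
        fetching.wait()
      }
      // The fetch may have completed while we waited (or before we locked).
      result = cache.get(id)
      // Still missing: we won the race, so mark the ID as being fetched.
      if (result == null) fetching += id
    }

    if (result == null) {
      try {
        result = fetchRemote(id)
        cache.put(id, result)
      } finally {
        // Always wake the waiters, even if the fetch throws.
        fetching.synchronized {
          fetching -= id
          fetching.notifyAll()
        }
      }
    }
    result
  }
}
```

The `try`/`finally` mirrors the original code: waiters are released even when the fetch fails, which is what prevents the hangs the comment in `getStatuses` refers to.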
http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/core/src/main/scala/org/apache/spark/Partitioner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index f83f527..93a9337 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -32,6 +32,9 @@ import org.apache.spark.util.random.SamplingUtils
 /**
  * An object that defines how the elements in a key-value pair RDD are partitioned by key.
  * Maps each key to a partition ID, from 0 to `numPartitions - 1`.
+ *
+ * Note that a partitioner must be deterministic, i.e. it must return the same partition id given
+ * the same partition key.
  */
 abstract class Partitioner extends Serializable {
   def numPartitions: Int

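The new `Partitioner` doc comment requires the same key to always map to the same partition ID. The sketch below contrasts a hash-based assignment with a round-robin one. The object and method names are hypothetical, and unlike `repartition`'s `distributePartition`, which starts from a random position, this round-robin version starts at 0 for simplicity:

```scala
object PartitionDeterminism {
  // Deterministic: the same key always yields the same partition ID,
  // regardless of the order in which records arrive.
  def hashPartition(key: Any, numPartitions: Int): Int = {
    val raw = key.hashCode % numPartitions
    if (raw < 0) raw + numPartitions else raw
  }

  // Order-sensitive: the ID depends on the record's position in the input,
  // so a different input order yields a different assignment.
  def roundRobin[T](records: Seq[T], numPartitions: Int): Seq[(Int, T)] =
    records.zipWithIndex.map { case (r, i) => (i % numPartitions, r) }
}
```

Because the round-robin ID depends on position, rerunning it over the same records in a different order sends them to different partitions, which is why `repartition`'s map function is treated as order-sensitive.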
http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/core/src/main/scala/org/apache/spark/executor/Executor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 36b1743..47c51c0 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -325,8 +325,14 @@ private[spark] class Executor(
           throw new TaskKilledException(killReason.get)
         }
 
-        logDebug("Task " + taskId + "'s epoch is " + task.epoch)
-        env.mapOutputTracker.updateEpoch(task.epoch)
+        // The purpose of updating the epoch here is to invalidate executor map output status cache
+        // in case FetchFailures have occurred. In local mode `env.mapOutputTracker` will be
+        // MapOutputTrackerMaster and its cache invalidation is not based on epoch numbers so
+        // we don't need to make any special calls here.
+        if (!isLocal) {
+          logDebug("Task " + taskId + "'s epoch is " + task.epoch)
+          env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch)
+        }
 
         // Run the actual task and measure its runtime.
         taskStart = System.currentTimeMillis()

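With this change, each non-local executor task calls `updateEpoch` before running, and the worker drops its cached map statuses only when it sees a strictly newer epoch. A minimal model of that rule, assuming a hypothetical `EpochCache` class with string values in place of `Array[MapStatus]`:

```scala
import scala.collection.concurrent.TrieMap

// Hypothetical worker-side cache keyed by shuffle ID, invalidated by epoch.
class EpochCache {
  private val epochLock = new Object
  private var epoch = 0L
  val statuses = TrieMap.empty[Int, String]

  def currentEpoch: Long = epochLock.synchronized(epoch)

  // Only a strictly newer epoch clears the cache; equal or older
  // epochs are ignored. Returns true if the cache was cleared.
  def updateEpoch(newEpoch: Long): Boolean = epochLock.synchronized {
    if (newEpoch > epoch) {
      epoch = newEpoch
      statuses.clear()
      true
    } else {
      false
    }
  }
}
```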
http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
index e4587c9..15128f0 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
@@ -23,11 +23,22 @@ import org.apache.spark.{Partition, TaskContext}
 
 /**
  * An RDD that applies the provided function to every partition of the parent RDD.
+ *
+ * @param prev the parent RDD.
+ * @param f The function used to map a tuple of (TaskContext, partition index, input iterator) to
+ *          an output iterator.
+ * @param preservesPartitioning Whether the input function preserves the partitioner, which should
+ *                              be `false` unless `prev` is a pair RDD and the input function
+ *                              doesn't modify the keys.
+ * @param isOrderSensitive whether or not the function is order-sensitive. If it's order
+ *                         sensitive, it may return a totally different result when the input order
+ *                         changes. Most stateful functions are order-sensitive.
  */
 private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
     var prev: RDD[T],
     f: (TaskContext, Int, Iterator[T]) => Iterator[U],  // (TaskContext, partition index, iterator)
-    preservesPartitioning: Boolean = false)
+    preservesPartitioning: Boolean = false,
+    isOrderSensitive: Boolean = false)
   extends RDD[U](prev) {
 
   override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None
@@ -41,4 +52,12 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
     super.clearDependencies()
     prev = null
   }
+
+  override protected def getOutputDeterministicLevel = {
+    if (isOrderSensitive && prev.outputDeterministicLevel == DeterministicLevel.UNORDERED) {
+      DeterministicLevel.INDETERMINATE
+    } else {
+      super.getOutputDeterministicLevel
+    }
+  }
 }

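`MapPartitionsRDD.getOutputDeterministicLevel` upgrades an UNORDERED parent to INDETERMINATE when the function is order-sensitive, and otherwise inherits the parent's level through the narrow dependency. A small model of that rule, using a local `Level` enumeration as a stand-in for Spark's private `DeterministicLevel` (all names here are hypothetical):

```scala
object MapPartitionsLevel {
  // Local stand-in for DeterministicLevel, declared in the same order.
  object Level extends Enumeration {
    val DETERMINATE, UNORDERED, INDETERMINATE = Value
  }

  def levelOf(parent: Level.Value, isOrderSensitive: Boolean): Level.Value =
    if (isOrderSensitive && parent == Level.UNORDERED) Level.INDETERMINATE
    else parent // narrow dependency: inherit the parent's level

  // An order-sensitive (stateful) function: tags each record with its position.
  def tagWithPosition(it: Iterator[String]): Iterator[(Int, String)] =
    it.zipWithIndex.map { case (s, i) => (i, s) }
}
```

`tagWithPosition` shows why the rule is needed: the same two records in a different order produce a different set of output pairs, so an unordered input turns into an indeterminate output.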
http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 102836d..4ff0f83 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -461,8 +461,9 @@ abstract class RDD[T: ClassTag](
 
       // include a shuffle step so that our upstream tasks are still distributed
       new CoalescedRDD(
-        new ShuffledRDD[Int, T, T](mapPartitionsWithIndex(distributePartition),
-        new HashPartitioner(numPartitions)),
+        new ShuffledRDD[Int, T, T](
+          mapPartitionsWithIndexInternal(distributePartition, isOrderSensitive = true),
+          new HashPartitioner(numPartitions)),
         numPartitions,
         partitionCoalescer).values
     } else {
@@ -806,16 +807,21 @@ abstract class RDD[T: ClassTag](
    * serializable and don't require closure cleaning.
    *
    * @param preservesPartitioning indicates whether the input function preserves the partitioner,
-   * which should be `false` unless this is a pair RDD and the input function doesn't modify
-   * the keys.
+   *                              which should be `false` unless this is a pair RDD and the input
+   *                              function doesn't modify the keys.
+   * @param isOrderSensitive whether or not the function is order-sensitive. If it's order
+   *                         sensitive, it may return a totally different result when the input
+   *                         order changes. Most stateful functions are order-sensitive.
    */
   private[spark] def mapPartitionsWithIndexInternal[U: ClassTag](
       f: (Int, Iterator[T]) => Iterator[U],
-      preservesPartitioning: Boolean = false): RDD[U] = withScope {
+      preservesPartitioning: Boolean = false,
+      isOrderSensitive: Boolean = false): RDD[U] = withScope {
     new MapPartitionsRDD(
       this,
       (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter),
-      preservesPartitioning)
+      preservesPartitioning = preservesPartitioning,
+      isOrderSensitive = isOrderSensitive)
   }
 
   /**
@@ -1635,6 +1641,16 @@ abstract class RDD[T: ClassTag](
   }
 
   /**
+   * Return whether this RDD is reliably checkpointed and materialized.
+   */
+  private[rdd] def isReliablyCheckpointed: Boolean = {
+    checkpointData match {
+      case Some(reliable: ReliableRDDCheckpointData[_]) if reliable.isCheckpointed => true
+      case _ => false
+    }
+  }
+
+  /**
    * Gets the name of the directory to which this RDD was checkpointed.
    * This is not defined if the RDD is checkpointed locally.
    */
@@ -1838,6 +1854,63 @@ abstract class RDD[T: ClassTag](
   def toJavaRDD() : JavaRDD[T] = {
     new JavaRDD(this)(elementClassTag)
   }
+
+  /**
+   * Returns the deterministic level of this RDD's output. Please refer to [[DeterministicLevel]]
+   * for the definition.
+   *
+   * By default, a reliably checkpointed RDD, or an RDD without parents (root RDD), is DETERMINATE.
+   * For RDDs with parents, we generate a deterministic level candidate per parent according to
+   * the dependency. The deterministic level of the current RDD is the least deterministic of
+   * these candidates. Please override [[getOutputDeterministicLevel]] to provide custom logic
+   * for calculating the output deterministic level.
+   */
+  // TODO: make it public so users can set deterministic level to their custom RDDs.
+  // TODO: this can be per-partition. e.g. UnionRDD can have different deterministic level for
+  // different partitions.
+  private[spark] final lazy val outputDeterministicLevel: DeterministicLevel.Value = {
+    if (isReliablyCheckpointed) {
+      DeterministicLevel.DETERMINATE
+    } else {
+      getOutputDeterministicLevel
+    }
+  }
+
+  @DeveloperApi
+  protected def getOutputDeterministicLevel: DeterministicLevel.Value = {
+    val deterministicLevelCandidates = dependencies.map {
+      // The shuffle is not really happening; treat it like a narrow dependency and assume the
+      // output deterministic level of the current RDD is the same as the parent's.
+      case dep: ShuffleDependency[_, _, _] if dep.rdd.partitioner.exists(_ == dep.partitioner) =>
+        dep.rdd.outputDeterministicLevel
+
+      case dep: ShuffleDependency[_, _, _] =>
+        if (dep.rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE) {
+          // If map output was indeterminate, shuffle output will be indeterminate as well
+          DeterministicLevel.INDETERMINATE
+        } else if (dep.keyOrdering.isDefined && dep.aggregator.isDefined) {
+          // If an aggregator is specified (so keys are unique) and a key ordering is specified,
+          // the output order is consistent.
+          DeterministicLevel.DETERMINATE
+        } else {
+          // In Spark, the reducer fetches multiple remote shuffle blocks at the same time, and
+          // the arrival order of these shuffle blocks is totally random. Even if the parent map
+          // RDD is DETERMINATE, the reduce RDD is always UNORDERED.
+          DeterministicLevel.UNORDERED
+        }
+
+      // For a narrow dependency, assume the output deterministic level of the current RDD is the
+      // same as the parent's.
+      case dep => dep.rdd.outputDeterministicLevel
+    }
+
+    if (deterministicLevelCandidates.isEmpty) {
+      // By default we assume the root RDD is determinate.
+      DeterministicLevel.DETERMINATE
+    } else {
+      deterministicLevelCandidates.maxBy(_.id)
+    }
+  }
 }
 
 
@@ -1891,3 +1964,18 @@ object RDD {
     new DoubleRDDFunctions(rdd.map(x => num.toDouble(x)))
   }
 }
+
+/**
+ * The deterministic level of RDD's output (i.e. what `RDD#compute` returns). This explains how
+ * the output will differ when Spark reruns the tasks for the RDD. There are 3 deterministic levels:
+ * 1. DETERMINATE: The RDD output is always the same data set in the same order after a rerun.
+ * 2. UNORDERED: The RDD output is always the same data set but the order can be different
+ *               after a rerun.
+ * 3. INDETERMINATE: The RDD output can be different after a rerun.
+ *
+ * Note that the output of an RDD usually relies on the parent RDDs. When the parent RDD's output
+ * is INDETERMINATE, it's very likely the RDD's output is also INDETERMINATE.
+ */
+private[spark] object DeterministicLevel extends Enumeration {
+  val DETERMINATE, UNORDERED, INDETERMINATE = Value
+}

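`getOutputDeterministicLevel` computes one candidate per dependency and picks the least deterministic one via `maxBy(_.id)`, which works because the enumeration values are declared from most to least deterministic. A simplified model of that logic, assuming hypothetical `Narrow`/`Shuffle` case classes in place of Spark's dependency types, and boolean flags in place of the partitioner, key-ordering and aggregator checks:

```scala
object DeterministicLevelModel {
  // Same declaration order as DeterministicLevel: a larger `id` means
  // "less deterministic".
  object Level extends Enumeration {
    val DETERMINATE, UNORDERED, INDETERMINATE = Value
  }

  // Hypothetical, simplified model of an RDD dependency.
  sealed trait Dep { def parentLevel: Level.Value }
  final case class Narrow(parentLevel: Level.Value) extends Dep
  final case class Shuffle(
      parentLevel: Level.Value,
      parentAlreadyPartitioned: Boolean, // parent's partitioner equals the shuffle's
      hasKeyOrdering: Boolean,
      hasAggregator: Boolean) extends Dep

  def candidate(dep: Dep): Level.Value = dep match {
    // No data actually moves, so behave like a narrow dependency.
    case s: Shuffle if s.parentAlreadyPartitioned => s.parentLevel
    case s: Shuffle =>
      if (s.parentLevel == Level.INDETERMINATE) Level.INDETERMINATE
      else if (s.hasKeyOrdering && s.hasAggregator) Level.DETERMINATE
      else Level.UNORDERED // fetched blocks arrive in arbitrary order
    case n: Narrow => n.parentLevel
  }

  def outputLevel(deps: Seq[Dep]): Level.Value =
    if (deps.isEmpty) Level.DETERMINATE // root RDD
    else deps.map(candidate).maxBy(_.id) // least deterministic candidate wins
}
```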
http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 099bc2e..cb6cdcd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -38,7 +38,7 @@ import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
-import org.apache.spark.rdd.{RDD, RDDCheckpointData}
+import org.apache.spark.rdd.{DeterministicLevel, RDD, RDDCheckpointData}
 import org.apache.spark.rpc.RpcTimeout
 import org.apache.spark.storage._
 import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
@@ -328,25 +328,14 @@ class DAGScheduler(
     val numTasks = rdd.partitions.length
     val parents = getOrCreateParentStages(rdd, jobId)
     val id = nextStageId.getAndIncrement()
-    val stage = new ShuffleMapStage(id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep)
+    val stage = new ShuffleMapStage(
+      id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep, mapOutputTracker)
 
     stageIdToStage(id) = stage
     shuffleIdToMapStage(shuffleDep.shuffleId) = stage
     updateJobIdStageIdMaps(jobId, stage)
 
-    if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
-      // A previously run stage generated partitions for this shuffle, so for each output
-      // that's still available, copy information about that output location to the new stage
-      // (so we don't unnecessarily re-compute that data).
-      val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
-      val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
-      (0 until locs.length).foreach { i =>
-        if (locs(i) ne null) {
-          // locs(i) will be null if missing
-          stage.addOutputLoc(i, locs(i))
-        }
-      }
-    } else {
+    if (!mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
       // Kind of ugly: need to register RDDs with the cache and map output tracker here
       // since we can't do it in the RDD constructor because # of partitions is unknown
       logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")")
@@ -1240,7 +1229,8 @@ class DAGScheduler(
               // The epoch of the task is acceptable (i.e., the task was launched after the most
               // recent failure we're aware of for the executor), so mark the task's output as
               // available.
-              shuffleStage.addOutputLoc(smt.partitionId, status)
+              mapOutputTracker.registerMapOutput(
+                shuffleStage.shuffleDep.shuffleId, smt.partitionId, status)
               // Remove the task's partition from pending partitions. This may have already been
               // done above, but will not have been done yet in cases where the task attempt was
               // from an earlier attempt of the stage (i.e., not the attempt that's currently
@@ -1257,16 +1247,14 @@ class DAGScheduler(
               logInfo("waiting: " + waitingStages)
               logInfo("failed: " + failedStages)
 
-              // We supply true to increment the epoch number here in case this is a
-              // recomputation of the map outputs. In that case, some nodes may have cached
-              // locations with holes (from when we detected the error) and will need the
-              // epoch incremented to refetch them.
-              // TODO: Only increment the epoch number if this is not the first time
-              //       we registered these map outputs.
-              mapOutputTracker.registerMapOutputs(
-                shuffleStage.shuffleDep.shuffleId,
-                shuffleStage.outputLocInMapOutputTrackerFormat(),
-                changeEpoch = true)
+              // This call to increment the epoch may not be strictly necessary, but it is retained
+              // for now in order to minimize the changes in behavior from an earlier version of the
+              // code. This existing behavior of always incrementing the epoch following any
+              // successful shuffle map stage completion may have benefits by causing unneeded
+              // cached map outputs to be cleaned up earlier on executors. In the future we can
+              // consider removing this call, but this will require some extra investigation.
+              // See https://github.com/apache/spark/pull/17955/files#r117385673 for more details.
+              mapOutputTracker.incrementEpoch()
 
               clearCacheLocs()
 
@@ -1344,6 +1332,63 @@ class DAGScheduler(
             failedStages += failedStage
             failedStages += mapStage
             if (noResubmitEnqueued) {
+              // If the map stage is INDETERMINATE, which means the map tasks may return
+              // different results when retried, we need to retry all the tasks of the failed
+              // stage and its succeeding stages, because the input data will change after the
+              // map tasks are retried.
+              // Note that if the map stage is UNORDERED, we are fine. The shuffle partitioner is
+              // guaranteed to be determinate, so the input data of the reducers will not change
+              // even if the map tasks are retried.
+              if (mapStage.rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE) {
+                // It's a little tricky to find all the succeeding stages of `failedStage`, because
+                // each stage only knows its parents, not its children. Here we traverse the stages
+                // from the leaf nodes (the result stages of active jobs), and roll back all the
+                // stages in the stage chains that connect to the `failedStage`. To speed up the
+                // traversal, we collect the stages to roll back first. If a stage needs to be
+                // rolled back, all its succeeding stages need to be rolled back too.
+                val stagesToRollback = scala.collection.mutable.HashSet(failedStage)
+
+                def collectStagesToRollback(stageChain: List[Stage]): Unit = {
+                  if (stagesToRollback.contains(stageChain.head)) {
+                    stageChain.drop(1).foreach(s => stagesToRollback += s)
+                  } else {
+                    stageChain.head.parents.foreach { s =>
+                      collectStagesToRollback(s :: stageChain)
+                    }
+                  }
+                }
+
+                def generateErrorMessage(stage: Stage): String = {
+                  "A shuffle map stage with indeterminate output was failed and retried. " +
+                    s"However, Spark cannot rollback the $stage to re-process the input data, " +
+                    "and has to fail this job. Please eliminate the indeterminacy by " +
+                    "checkpointing the RDD before repartition and try again."
+                }
+
+                activeJobs.foreach(job => collectStagesToRollback(job.finalStage :: Nil))
+
+                stagesToRollback.foreach {
+                  case mapStage: ShuffleMapStage =>
+                    val numMissingPartitions = mapStage.findMissingPartitions().length
+                    if (numMissingPartitions < mapStage.numTasks) {
+                      // TODO: support to rollback shuffle files.
+                      // Currently the shuffle writing is "first write wins", so we can't re-run a
+                      // shuffle map stage and overwrite existing shuffle files. We have to finish
+                      // SPARK-8029 first.
+                      abortStage(mapStage, generateErrorMessage(mapStage), None)
+                    }
+
+                  case resultStage: ResultStage if resultStage.activeJob.isDefined =>
+                    val numMissingPartitions = resultStage.findMissingPartitions().length
+                    if (numMissingPartitions < resultStage.numTasks) {
+                      // TODO: support to rollback result tasks.
+                      abortStage(resultStage, generateErrorMessage(resultStage), None)
+                    }
+
+                  case _ =>
+                }
+              }
+
               // We expect one executor failure to trigger many FetchFailures in rapid succession,
               // but all of those task failures can typically be handled by a single resubmission of
               // the failed stage.  We avoid flooding the scheduler's event queue with resubmit
@@ -1367,7 +1412,6 @@ class DAGScheduler(
           }
           // Mark the map whose fetch failed as broken in the map stage
           if (mapId != -1) {
-            mapStage.removeOutputLoc(mapId, bmAddress)
             mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
           }
 
@@ -1416,17 +1460,7 @@ class DAGScheduler(
 
       if (filesLost || !env.blockManager.externalShuffleServiceEnabled) {
         logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch))
-        // TODO: This will be really slow if we keep accumulating shuffle map stages
-        for ((shuffleId, stage) <- shuffleIdToMapStage) {
-          stage.removeOutputsOnExecutor(execId)
-          mapOutputTracker.registerMapOutputs(
-            shuffleId,
-            stage.outputLocInMapOutputTrackerFormat(),
-            changeEpoch = true)
-        }
-        if (shuffleIdToMapStage.isEmpty) {
-          mapOutputTracker.incrementEpoch()
-        }
+        mapOutputTracker.removeOutputsOnExecutor(execId)
         clearCacheLocs()
       }
     } else {

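Because each stage knows only its parents, the rollback logic walks upward from every active job's final stage, carrying the chain of descendants it came from, and marks that chain once it reaches a stage already selected for rollback. A standalone sketch of the same traversal, assuming a hypothetical `Stage` case class rather than Spark's scheduler types:

```scala
import scala.collection.mutable

object RollbackSketch {
  // Hypothetical stage model: like the DAGScheduler's stages, each one
  // knows only its parents.
  final case class Stage(id: Int, parents: List[Stage])

  def stagesToRollback(failedStage: Stage, finalStages: Seq[Stage]): Set[Stage] = {
    val toRollback = mutable.HashSet(failedStage)

    // `chain` runs from the current stage (head) down to a job's final stage.
    def collect(chain: List[Stage]): Unit = {
      if (toRollback.contains(chain.head)) {
        // Every stage below a rolled-back stage must be rolled back too.
        chain.drop(1).foreach(toRollback += _)
      } else {
        chain.head.parents.foreach(p => collect(p :: chain))
      }
    }

    finalStages.foreach(s => collect(s :: Nil))
    toRollback.toSet
  }
}
```

As in the original, stages reachable only through unrelated jobs are left alone, and ancestors of the failed stage are never added.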
http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
index db4d9ef..05f650f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
@@ -19,9 +19,8 @@ package org.apache.spark.scheduler
 
 import scala.collection.mutable.HashSet
 
-import org.apache.spark.ShuffleDependency
+import org.apache.spark.{MapOutputTrackerMaster, ShuffleDependency, SparkEnv}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.BlockManagerId
 import org.apache.spark.util.CallSite
 
 /**
@@ -42,13 +41,12 @@ private[spark] class ShuffleMapStage(
     parents: List[Stage],
     firstJobId: Int,
     callSite: CallSite,
-    val shuffleDep: ShuffleDependency[_, _, _])
+    val shuffleDep: ShuffleDependency[_, _, _],
+    mapOutputTrackerMaster: MapOutputTrackerMaster)
   extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) {
 
   private[this] var _mapStageJobs: List[ActiveJob] = Nil
 
-  private[this] var _numAvailableOutputs: Int = 0
-
   /**
    * Partitions that either haven't yet been computed, or that were computed on an executor
    * that has since been lost, so should be re-computed.  This variable is used by the
@@ -60,13 +58,6 @@ private[spark] class ShuffleMapStage(
    */
   val pendingPartitions = new HashSet[Int]
 
-  /**
-   * List of [[MapStatus]] for each partition. The index of the array is the map partition id,
-   * and each value in the array is the list of possible [[MapStatus]] for a partition
-   * (a single task might run multiple times).
-   */
-  private[this] val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
-
   override def toString: String = "ShuffleMapStage " + id
 
   /**
@@ -88,69 +79,18 @@ private[spark] class ShuffleMapStage(
   /**
    * Number of partitions that have shuffle outputs.
    * When this reaches [[numPartitions]], this map stage is ready.
-   * This should be kept consistent as `outputLocs.filter(!_.isEmpty).size`.
    */
-  def numAvailableOutputs: Int = _numAvailableOutputs
+  def numAvailableOutputs: Int = mapOutputTrackerMaster.getNumAvailableOutputs(shuffleDep.shuffleId)
 
   /**
    * Returns true if the map stage is ready, i.e. all partitions have shuffle outputs.
-   * This should be the same as `outputLocs.contains(Nil)`.
    */
-  def isAvailable: Boolean = _numAvailableOutputs == numPartitions
+  def isAvailable: Boolean = numAvailableOutputs == numPartitions
 
   /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */
   override def findMissingPartitions(): Seq[Int] = {
-    val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty)
-    assert(missing.size == numPartitions - _numAvailableOutputs,
-      s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}")
-    missing
-  }
-
-  def addOutputLoc(partition: Int, status: MapStatus): Unit = {
-    val prevList = outputLocs(partition)
-    outputLocs(partition) = status :: prevList
-    if (prevList == Nil) {
-      _numAvailableOutputs += 1
-    }
-  }
-
-  def removeOutputLoc(partition: Int, bmAddress: BlockManagerId): Unit = {
-    val prevList = outputLocs(partition)
-    val newList = prevList.filterNot(_.location == bmAddress)
-    outputLocs(partition) = newList
-    if (prevList != Nil && newList == Nil) {
-      _numAvailableOutputs -= 1
-    }
-  }
-
-  /**
-   * Returns an array of [[MapStatus]] (index by partition id). For each partition, the returned
-   * value contains only one (i.e. the first) [[MapStatus]]. If there is no entry for the partition,
-   * that position is filled with null.
-   */
-  def outputLocInMapOutputTrackerFormat(): Array[MapStatus] = {
-    outputLocs.map(_.headOption.orNull)
-  }
-
-  /**
-   * Removes all shuffle outputs associated with this executor. Note that this will also remove
-   * outputs which are served by an external shuffle server (if one exists), as they are still
-   * registered with this execId.
-   */
-  def removeOutputsOnExecutor(execId: String): Unit = {
-    var becameUnavailable = false
-    for (partition <- 0 until numPartitions) {
-      val prevList = outputLocs(partition)
-      val newList = prevList.filterNot(_.location.executorId == execId)
-      outputLocs(partition) = newList
-      if (prevList != Nil && newList == Nil) {
-        becameUnavailable = true
-        _numAvailableOutputs -= 1
-      }
-    }
-    if (becameUnavailable) {
-      logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format(
-        this, execId, _numAvailableOutputs, numPartitions, isAvailable))
-    }
+    mapOutputTrackerMaster
+      .findMissingPartitions(shuffleDep.shuffleId)
+      .getOrElse(0 until numPartitions)
   }
 }

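After this change, `ShuffleMapStage` no longer tracks output locations itself: `numAvailableOutputs` and `findMissingPartitions` delegate to the `MapOutputTrackerMaster`, and an unknown shuffle is treated as having every partition missing. A simplified sketch of that delegation, assuming hypothetical `TinyTracker` and `TinyStage` classes with string locations in place of `MapStatus`:

```scala
import scala.collection.mutable

// Hypothetical stand-in for the tracker's per-shuffle state: one slot per
// map partition, None until that map output is registered.
class TinyTracker {
  private val shuffles = mutable.HashMap.empty[Int, Array[Option[String]]]

  def registerShuffle(shuffleId: Int, numMaps: Int): Unit =
    shuffles(shuffleId) = Array.fill[Option[String]](numMaps)(None)

  def registerMapOutput(shuffleId: Int, mapId: Int, location: String): Unit =
    shuffles(shuffleId)(mapId) = Some(location)

  def unregisterMapOutput(shuffleId: Int, mapId: Int): Unit =
    shuffles(shuffleId)(mapId) = None

  def numAvailableOutputs(shuffleId: Int): Int =
    shuffles.get(shuffleId).map(_.count(_.isDefined)).getOrElse(0)

  // None when the shuffle is unknown, so the caller picks the fallback.
  def findMissingPartitions(shuffleId: Int): Option[Seq[Int]] =
    shuffles.get(shuffleId).map(slots => slots.indices.filter(i => slots(i).isEmpty))
}

// The stage keeps no copy of output locations; it always asks the tracker.
class TinyStage(shuffleId: Int, numPartitions: Int, tracker: TinyTracker) {
  def findMissingPartitions(): Seq[Int] =
    tracker.findMissingPartitions(shuffleId).getOrElse(0 until numPartitions)

  def isAvailable: Boolean = tracker.numAvailableOutputs(shuffleId) == numPartitions
}
```

Keeping a single source of truth is what lets the removed `addOutputLoc`/`removeOutputLoc` bookkeeping go away: the scheduler registers and unregisters outputs only on the tracker.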
http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index f8c62b4..bc0d470 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -129,7 +129,7 @@ private[spark] class TaskSchedulerImpl private[scheduler](
 
   var backend: SchedulerBackend = null
 
-  val mapOutputTracker = SparkEnv.get.mapOutputTracker
+  val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
 
   private var schedulableBuilder: SchedulableBuilder = null
   // default scheduler is FIFO

http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index ca94fd1..82b6fd1 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -138,21 +138,21 @@ class MapOutputTrackerSuite extends SparkFunSuite {
       slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
 
     masterTracker.registerShuffle(10, 1)
-    masterTracker.incrementEpoch()
     slaveTracker.updateEpoch(masterTracker.getEpoch)
+    // This is expected to fail because no outputs have been registered for the shuffle.
     intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) }
 
     val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
     masterTracker.registerMapOutput(10, 0, MapStatus(
       BlockManagerId("a", "hostA", 1000), Array(1000L)))
-    masterTracker.incrementEpoch()
     slaveTracker.updateEpoch(masterTracker.getEpoch)
     assert(slaveTracker.getMapSizesByExecutorId(10, 0) ===
       Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000)))))
     assert(0 == masterTracker.getNumCachedSerializedBroadcast)
 
+    val masterTrackerEpochBeforeLossOfMapOutput = masterTracker.getEpoch
     masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
-    masterTracker.incrementEpoch()
+    assert(masterTracker.getEpoch > masterTrackerEpochBeforeLossOfMapOutput)
     slaveTracker.updateEpoch(masterTracker.getEpoch)
     intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) }
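The updated test no longer calls `incrementEpoch()` itself. Instead, it asserts that `unregisterMapOutput` bumps the master's epoch. As a result, a worker that calls `updateEpoch` with the newer value drops its cached map statuses, and the next `getMapSizesByExecutorId` call fails. The sketch below captures that contract using a hypothetical `EpochTracker`. It is simplified: it bumps the epoch on every unregister call and does not model the worker side.

```scala
import scala.collection.mutable

// Hypothetical toy model of the master-side epoch rule; not Spark's MapOutputTrackerMaster.
final class EpochTracker {
  private var epoch: Long = 0L
  // (shuffleId, mapId) -> location of the registered map output
  private val outputs = mutable.Map.empty[(Int, Int), String]

  def getEpoch: Long = synchronized { epoch }

  def registerMapOutput(shuffleId: Int, mapId: Int, location: String): Unit = synchronized {
    outputs((shuffleId, mapId)) = location
  }

  // Removing an output bumps the epoch inside the call, so callers no longer
  // need a separate incrementEpoch() (simplified: bumps even if nothing matched).
  def unregisterMapOutput(shuffleId: Int, mapId: Int, location: String): Unit = synchronized {
    if (outputs.get((shuffleId, mapId)).contains(location)) {
      outputs.remove((shuffleId, mapId))
    }
    epoch += 1
  }

  def hasOutput(shuffleId: Int, mapId: Int): Boolean = synchronized {
    outputs.contains((shuffleId, mapId))
  }
}
```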
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 3b564df..62c40d1 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -333,6 +333,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
     val shuffleMapRdd = new MyRDD(sc, 1, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1))
     val shuffleHandle = manager.registerShuffle(0, 1, shuffleDep)
+    mapTrackerMaster.registerShuffle(0, 1)
 
     // first attempt -- its successful
     val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
@@ -367,7 +368,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
 
     // register one of the map outputs -- doesn't matter which one
     mapOutput1.foreach { case mapStatus =>
-      mapTrackerMaster.registerMapOutputs(0, Array(mapStatus))
+      mapTrackerMaster.registerMapOutput(0, 0, mapStatus)
     }
 
     val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1,

http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
index 2b18ebe..571c6bb 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
@@ -86,7 +86,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
     sc = new SparkContext(conf)
     val scheduler = mock[TaskSchedulerImpl]
     when(scheduler.sc).thenReturn(sc)
-    when(scheduler.mapOutputTracker).thenReturn(SparkEnv.get.mapOutputTracker)
+    when(scheduler.mapOutputTracker).thenReturn(
+      SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster])
     scheduler
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 9112065..1fff0d0 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -30,7 +30,7 @@ import org.scalatest.time.SpanSugar._
 
 import org.apache.spark._
 import org.apache.spark.broadcast.BroadcastManager
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{DeterministicLevel, RDD}
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
 import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException}
 import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
@@ -56,6 +56,20 @@ class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler)
 
 }
 
+class MyCheckpointRDD(
+    sc: SparkContext,
+    numPartitions: Int,
+    dependencies: List[Dependency[_]],
+    locations: Seq[Seq[String]] = Nil,
+    @(transient @param) tracker: MapOutputTrackerMaster = null,
+    indeterminate: Boolean = false)
+  extends MyRDD(sc, numPartitions, dependencies, locations, tracker, indeterminate) {
+
+  // Allow doCheckpoint() on this RDD.
+  override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+    Iterator.empty
+}
+
 /**
  * An RDD for passing to DAGScheduler. These RDDs will use the dependencies and
  * preferredLocations (if any) that are passed to them. They are deliberately not executable
@@ -70,7 +84,8 @@ class MyRDD(
     numPartitions: Int,
     dependencies: List[Dependency[_]],
     locations: Seq[Seq[String]] = Nil,
-    @(transient @param) tracker: MapOutputTrackerMaster = null)
+    @(transient @param) tracker: MapOutputTrackerMaster = null,
+    indeterminate: Boolean = false)
   extends RDD[(Int, Int)](sc, dependencies) with Serializable {
 
   override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
@@ -80,6 +95,10 @@ class MyRDD(
     override def index: Int = i
   }).toArray
 
+  override protected def getOutputDeterministicLevel = {
+    if (indeterminate) DeterministicLevel.INDETERMINATE else super.getOutputDeterministicLevel
+  }
+
   override def getPreferredLocations(partition: Partition): Seq[String] = {
     if (locations.isDefinedAt(partition.index)) {
       locations(partition.index)
@@ -2307,6 +2326,152 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
     }
   }
 
+  test("SPARK-23207: retry all the succeeding stages when the map stage is indeterminate") {
+    val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true)
+
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2))
+    val shuffleId1 = shuffleDep1.shuffleId
+    val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(2))
+    val shuffleId2 = shuffleDep2.shuffleId
+    val finalRdd = new MyRDD(sc, 2, List(shuffleDep2), tracker = mapOutputTracker)
+
+    submit(finalRdd, Array(0, 1))
+
+    // Finish the first shuffle map stage.
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", 2)),
+      (Success, makeMapStatus("hostB", 2))))
+    assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty))
+
+    // Finish the second shuffle map stage.
+    complete(taskSets(1), Seq(
+      (Success, makeMapStatus("hostC", 2)),
+      (Success, makeMapStatus("hostD", 2))))
+    assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty))
+
+    // The first task of the final stage failed with fetch failure
+    runEvent(makeCompletionEvent(
+      taskSets(2).tasks(0),
+      FetchFailed(makeBlockManagerId("hostC"), shuffleId2, 0, 0, "ignored"),
+      null))
+
+    val failedStages = scheduler.failedStages.toSeq
+    assert(failedStages.length == 2)
+    // Shuffle blocks of "hostC" is lost, so first task of the `shuffleMapRdd2` needs to retry.
+    assert(failedStages.collect {
+      case stage: ShuffleMapStage if stage.shuffleDep.shuffleId == shuffleId2 => stage
+    }.head.findMissingPartitions() == Seq(0))
+    // The result stage is still waiting for its 2 tasks to complete
+    assert(failedStages.collect {
+      case stage: ResultStage => stage
+    }.head.findMissingPartitions() == Seq(0, 1))
+
+    scheduler.resubmitFailedStages()
+
+    // The first task of the `shuffleMapRdd2` failed with fetch failure
+    runEvent(makeCompletionEvent(
+      taskSets(3).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0, 0, "ignored"),
+      null))
+
+    // The job should fail because Spark can't rollback the shuffle map stage.
+    assert(failure != null && failure.getMessage.contains("Spark cannot rollback"))
+  }
+
+  private def assertResultStageFailToRollback(mapRdd: MyRDD): Unit = {
+    val shuffleDep = new ShuffleDependency(mapRdd, new HashPartitioner(2))
+    val shuffleId = shuffleDep.shuffleId
+    val finalRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
+
+    submit(finalRdd, Array(0, 1))
+
+    completeShuffleMapStageSuccessfully(taskSets.length - 1, 0, numShufflePartitions = 2)
+    assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty))
+
+    // Finish the first task of the result stage
+    runEvent(makeCompletionEvent(
+      taskSets.last.tasks(0), Success, 42,
+      Seq.empty, createFakeTaskInfoWithId(0)))
+
+    // Fail the second task with FetchFailed.
+    runEvent(makeCompletionEvent(
+      taskSets.last.tasks(1),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+      null))
+
+    // The job should fail because Spark can't rollback the result stage.
+    assert(failure != null && failure.getMessage.contains("Spark cannot rollback"))
+  }
+
+  test("SPARK-23207: cannot rollback a result stage") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil, indeterminate = true)
+    assertResultStageFailToRollback(shuffleMapRdd)
+  }
+
+  test("SPARK-23207: local checkpoint fail to rollback (checkpointed before)") {
+    val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true)
+    shuffleMapRdd.localCheckpoint()
+    shuffleMapRdd.doCheckpoint()
+    assertResultStageFailToRollback(shuffleMapRdd)
+  }
+
+  test("SPARK-23207: local checkpoint fail to rollback (checkpointing now)") {
+    val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true)
+    shuffleMapRdd.localCheckpoint()
+    assertResultStageFailToRollback(shuffleMapRdd)
+  }
+
+  private def assertResultStageNotRollbacked(mapRdd: MyRDD): Unit = {
+    val shuffleDep = new ShuffleDependency(mapRdd, new HashPartitioner(2))
+    val shuffleId = shuffleDep.shuffleId
+    val finalRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
+
+    submit(finalRdd, Array(0, 1))
+
+    completeShuffleMapStageSuccessfully(taskSets.length - 1, 0, numShufflePartitions = 2)
+    assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty))
+
+    // Finish the first task of the result stage
+    runEvent(makeCompletionEvent(
+      taskSets.last.tasks(0), Success, 42,
+      Seq.empty, createFakeTaskInfoWithId(0)))
+
+    // Fail the second task with FetchFailed.
+    runEvent(makeCompletionEvent(
+      taskSets.last.tasks(1),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+      null))
+
+    assert(failure == null, "job should not fail")
+    val failedStages = scheduler.failedStages.toSeq
+    assert(failedStages.length == 2)
+    // Shuffle blocks of "hostA" is lost, so first task of the `shuffleMapRdd2` needs to retry.
+    assert(failedStages.collect {
+      case stage: ShuffleMapStage if stage.shuffleDep.shuffleId == shuffleId => stage
+    }.head.findMissingPartitions() == Seq(0))
+    // The first task of result stage remains completed.
+    assert(failedStages.collect {
+      case stage: ResultStage => stage
+    }.head.findMissingPartitions() == Seq(1))
+  }
+
+  test("SPARK-23207: reliable checkpoint can avoid rollback (checkpointed before)") {
+    sc.setCheckpointDir(Utils.createTempDir().getCanonicalPath)
+    val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true)
+    shuffleMapRdd.checkpoint()
+    shuffleMapRdd.doCheckpoint()
+    assertResultStageNotRollbacked(shuffleMapRdd)
+  }
+
+  test("SPARK-23207: reliable checkpoint fail to rollback (checkpointing now)") {
+    sc.setCheckpointDir(Utils.createTempDir().getCanonicalPath)
+    val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true)
+    shuffleMapRdd.checkpoint()
+    assertResultStageFailToRollback(shuffleMapRdd)
+  }
+
   /**
    * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
    * Note that this checks only the host and not the executor ID.
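The new SPARK-23207 tests pin down when a fetch failure against an indeterminate map stage can be recovered. Read together with the PR description, the rule they exercise can be modelled roughly as follows. This is a hypothetical, simplified model (`RollbackSketch` and its cases are illustrative names), not the `DAGScheduler` code.

- Determinate or unordered output only needs the missing tasks rerun.
- A reliable checkpoint completed before the job makes the map output reproducible.
- Otherwise, all succeeding tasks must rerun. The job is aborted if a stage to roll back has already produced output, as in the tests expecting "Spark cannot rollback".

```scala
// Hypothetical model of the recovery rule the SPARK-23207 tests exercise.
// Names are illustrative; this is not DAGScheduler code.
object RollbackSketch {
  sealed trait DeterministicLevel
  case object Determinate extends DeterministicLevel
  case object Unordered extends DeterministicLevel
  case object Indeterminate extends DeterministicLevel

  sealed trait Decision
  case object RerunMissingTasksOnly extends Decision
  case object RerunAllSucceedingTasks extends Decision
  final case class AbortJob(reason: String) extends Decision

  /**
   * @param mapOutputLevel             determinism of the map stage that must be re-run
   * @param reliablyCheckpointedBefore map RDD was reliably checkpointed before the job ran
   *                                   (a local checkpoint, or one still in progress, does not count)
   * @param someStageHasOutput         a stage that would have to be rolled back already has
   *                                   finished tasks whose output cannot be discarded
   */
  def onMapStageRerun(
      mapOutputLevel: DeterministicLevel,
      reliablyCheckpointedBefore: Boolean,
      someStageHasOutput: Boolean): Decision = {
    // Reading back a completed reliable checkpoint reproduces the same data,
    // so the rerun behaves as if the output were determinate.
    val reproducible = mapOutputLevel != Indeterminate || reliablyCheckpointedBefore
    if (reproducible) RerunMissingTasksOnly
    else if (someStageHasOutput) AbortJob("cannot roll back stages that already produced output")
    else RerunAllSucceedingTasks
  }
}
```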

http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
index c0ba513..4496afb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
@@ -247,6 +247,9 @@ object ShuffleExchange {
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
     }
 
+    val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] &&
+      newPartitioning.numPartitions > 1
+
     val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = {
       // [SPARK-23207] Have to make sure the generated RoundRobinPartitioning is deterministic,
       // otherwise a retry task may output different rows and thus lead to data loss.
@@ -256,9 +259,7 @@ object ShuffleExchange {
       //
       // Note that we don't perform local sort if the new partitioning has only 1 partition, under
       // that case all output rows go to the same partition.
-      val newRdd = if (SparkEnv.get.conf.get(SQLConf.SORT_BEFORE_REPARTITION) &&
-          newPartitioning.numPartitions > 1 &&
-          newPartitioning.isInstanceOf[RoundRobinPartitioning]) {
+      val newRdd = if (isRoundRobin && SparkEnv.get.conf.get(SQLConf.SORT_BEFORE_REPARTITION)) {
         rdd.mapPartitionsInternal { iter =>
           val recordComparatorSupplier = new Supplier[RecordComparator] {
             override def get: RecordComparator = new RecordBinaryComparator()
@@ -294,17 +295,19 @@ object ShuffleExchange {
         rdd
       }
 
+      // round-robin function is order sensitive if we don't sort the input.
+      val isOrderSensitive = isRoundRobin && !SparkEnv.get.conf.get(SQLConf.SORT_BEFORE_REPARTITION)
       if (needToCopyObjectsBeforeShuffle(part, serializer)) {
-        newRdd.mapPartitionsInternal { iter =>
+        newRdd.mapPartitionsWithIndexInternal((_, iter) => {
           val getPartitionKey = getPartitionKeyExtractor()
           iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) }
-        }
+        }, isOrderSensitive = isOrderSensitive)
       } else {
-        newRdd.mapPartitionsInternal { iter =>
+        newRdd.mapPartitionsWithIndexInternal((_, iter) => {
           val getPartitionKey = getPartitionKeyExtractor()
           val mutablePair = new MutablePair[Int, InternalRow]()
           iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) }
-        }
+        }, isOrderSensitive = isOrderSensitive)
       }
     }
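A stateful round-robin is order sensitive because a row's partition id depends on its position in the iterator. If the same rows arrive in a different order on a retry, they land in different partitions. The plain-Scala sketch below (not Spark code) illustrates this, assuming a fixed starting offset passed as a parameter; the real exchange picks its own starting position per task. It uses `sorted` as a stand-in for the local sort enabled by `SQLConf.SORT_BEFORE_REPARTITION`. `RoundRobinSketch` is a hypothetical name.

```scala
// Hypothetical sketch using plain Scala collections: rows are assigned to
// partitions round-robin by their position, as a stateful map over an iterator would do.
object RoundRobinSketch {
  def assign[T](rows: Seq[T], numPartitions: Int, start: Int = 0): Map[Int, Seq[T]] =
    rows.zipWithIndex
      .map { case (row, i) => ((start + i) % numPartitions, row) }
      .groupBy(_._1)
      .map { case (p, pairs) => p -> pairs.map(_._2) }

  // Sorting first makes the assignment depend only on the set of rows,
  // not on the order in which they arrive.
  def assignSorted[T: Ordering](rows: Seq[T], numPartitions: Int): Map[Int, Seq[T]] =
    assign(rows.sorted, numPartitions)
}
```

Once rows are sorted, the assignment no longer depends on arrival order. This matches the diff, which marks the RDD as order sensitive only when `isRoundRobin` holds and the sort is disabled.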
 

