Posted to commits@spark.apache.org by li...@apache.org on 2018/07/26 19:09:06 UTC

spark git commit: [SPARK-24795][CORE] Implement barrier execution mode

Repository: spark
Updated Branches:
  refs/heads/master 5ed7660d1 -> e3486e1b9


[SPARK-24795][CORE] Implement barrier execution mode

## What changes were proposed in this pull request?

Propose new APIs and modify job/task scheduling to support barrier execution mode, which requires all tasks in the same barrier stage to start at the same time, and retries all tasks whenever some of them fail in the middle. The barrier execution mode is useful for some ML/DL workloads.

The proposed API changes include:

- `RDDBarrier` that marks an RDD as barrier (Spark must launch all the tasks of the current stage together).
- `BarrierTaskContext` that supports a global sync of all tasks in a barrier stage and provides extra `BarrierTaskInfo`s.

In the DAGScheduler, we retry all tasks of a barrier stage when some of them fail in the middle. This is achieved by unregistering all map outputs for the shuffleId (for a ShuffleMapStage) or clearing the finished partitions in the active job (for a ResultStage).
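
For illustration, a minimal usage sketch of the proposed APIs (the app name, master, and partition counts below are illustrative; the function signatures follow this patch, and `barrier()` is still a no-op placeholder at this commit, see SPARK-24817):

```scala
import org.apache.spark.{SparkConf, SparkContext}

object BarrierExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("barrier-demo").setMaster("local[4]"))
    val rdd = sc.parallelize(1 to 100, 4)

    // rdd.barrier() returns an RDDBarrier; Spark must launch all 4 tasks of this stage together.
    val doubled = rdd.barrier().mapPartitions { (iter, context) =>
      // Addresses of all tasks in this barrier stage, ordered by partitionId.
      val peers = context.getTaskInfos().map(_.address)
      println(s"Barrier stage peers: ${peers.mkString(", ")}")
      // Global sync point across all tasks of the stage (no-op until SPARK-24817 lands).
      context.barrier()
      iter.map(_ * 2)
    }
    println(doubled.collect().length)
    sc.stop()
  }
}
```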

## How was this patch tested?

Add `RDDBarrierSuite` to ensure we convert RDDs correctly;
Add new test cases in `DAGSchedulerSuite` to ensure we do task scheduling correctly;
Add new test cases in `SparkContextSuite` to ensure the barrier execution mode actually works (both under local mode and local cluster mode).
Add new test cases in `TaskSchedulerImplSuite` to ensure we schedule all tasks of a barrier taskSet together.

Author: Xingbo Jiang <xi...@databricks.com>

Closes #21758 from jiangxb1987/barrier-execution-mode.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e3486e1b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e3486e1b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e3486e1b

Branch: refs/heads/master
Commit: e3486e1b9556e00bc9c392a5b8440ab366780f9b
Parents: 5ed7660
Author: Xingbo Jiang <xi...@databricks.com>
Authored: Thu Jul 26 12:09:01 2018 -0700
Committer: Xiao Li <ga...@gmail.com>
Committed: Thu Jul 26 12:09:01 2018 -0700

----------------------------------------------------------------------
 .../org/apache/spark/BarrierTaskContext.scala   |  42 ++++++
 .../apache/spark/BarrierTaskContextImpl.scala   |  49 +++++++
 .../org/apache/spark/BarrierTaskInfo.scala      |  31 +++++
 .../org/apache/spark/MapOutputTracker.scala     |  12 ++
 .../org/apache/spark/rdd/MapPartitionsRDD.scala |  15 ++-
 .../main/scala/org/apache/spark/rdd/RDD.scala   |  27 +++-
 .../scala/org/apache/spark/rdd/RDDBarrier.scala |  52 ++++++++
 .../org/apache/spark/rdd/ShuffledRDD.scala      |   2 +
 .../org/apache/spark/scheduler/ActiveJob.scala  |   6 +
 .../apache/spark/scheduler/DAGScheduler.scala   | 131 ++++++++++++++++---
 .../org/apache/spark/scheduler/ResultTask.scala |   9 +-
 .../apache/spark/scheduler/ShuffleMapTask.scala |   7 +-
 .../org/apache/spark/scheduler/Stage.scala      |   8 +-
 .../scala/org/apache/spark/scheduler/Task.scala |  41 ++++--
 .../spark/scheduler/TaskDescription.scala       |   7 +-
 .../spark/scheduler/TaskSchedulerImpl.scala     |  66 ++++++++--
 .../apache/spark/scheduler/TaskSetManager.scala |   9 +-
 .../apache/spark/scheduler/WorkerOffer.scala    |   8 +-
 .../cluster/CoarseGrainedSchedulerBackend.scala |   6 +-
 .../scheduler/local/LocalSchedulerBackend.scala |   3 +-
 .../org/apache/spark/SparkContextSuite.scala    |  42 ++++++
 .../apache/spark/executor/ExecutorSuite.scala   |   1 +
 .../org/apache/spark/rdd/RDDBarrierSuite.scala  |  43 ++++++
 .../spark/scheduler/DAGSchedulerSuite.scala     |  58 ++++++++
 .../org/apache/spark/scheduler/FakeTask.scala   |  24 +++-
 .../spark/scheduler/TaskDescriptionSuite.scala  |   2 +
 .../scheduler/TaskSchedulerImplSuite.scala      |  34 +++++
 .../MesosFineGrainedSchedulerBackendSuite.scala |   2 +
 28 files changed, 673 insertions(+), 64 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
new file mode 100644
index 0000000..4c35862
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.spark.annotation.{Experimental, Since}
+
+/** A [[TaskContext]] with extra info and tooling for a barrier stage. */
+trait BarrierTaskContext extends TaskContext {
+
+  /**
+   * :: Experimental ::
+   * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to
+   * the MPI_Barrier function in MPI, the barrier() call blocks until all tasks in the same
+   * stage have reached this routine.
+   */
+  @Experimental
+  @Since("2.4.0")
+  def barrier(): Unit
+
+  /**
+   * :: Experimental ::
+   * Returns all the task infos in this barrier stage, ordered by partitionId.
+   */
+  @Experimental
+  @Since("2.4.0")
+  def getTaskInfos(): Array[BarrierTaskInfo]
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala
new file mode 100644
index 0000000..8ac7057
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.util.Properties
+
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.memory.TaskMemoryManager
+import org.apache.spark.metrics.MetricsSystem
+
+/** A [[BarrierTaskContext]] implementation. */
+private[spark] class BarrierTaskContextImpl(
+    override val stageId: Int,
+    override val stageAttemptNumber: Int,
+    override val partitionId: Int,
+    override val taskAttemptId: Long,
+    override val attemptNumber: Int,
+    override val taskMemoryManager: TaskMemoryManager,
+    localProperties: Properties,
+    @transient private val metricsSystem: MetricsSystem,
+    // The default value is only used in tests.
+    override val taskMetrics: TaskMetrics = TaskMetrics.empty)
+  extends TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber,
+      taskMemoryManager, localProperties, metricsSystem, taskMetrics)
+    with BarrierTaskContext {
+
+  // TODO SPARK-24817 implement global barrier.
+  override def barrier(): Unit = {}
+
+  override def getTaskInfos(): Array[BarrierTaskInfo] = {
+    val addressesStr = localProperties.getProperty("addresses", "")
+    addressesStr.split(",").map(_.trim()).map(new BarrierTaskInfo(_))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/main/scala/org/apache/spark/BarrierTaskInfo.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskInfo.scala b/core/src/main/scala/org/apache/spark/BarrierTaskInfo.scala
new file mode 100644
index 0000000..ce2653d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/BarrierTaskInfo.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.spark.annotation.{Experimental, Since}
+
+
+/**
+ * :: Experimental ::
+ * Carries all task infos of a barrier task.
+ *
+ * @param address the IPv4 address (host:port) of the executor that a barrier task is running on
+ */
+@Experimental
+@Since("2.4.0")
+class BarrierTaskInfo(val address: String)

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 7364605..1c4fa4b 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -434,6 +434,18 @@ private[spark] class MapOutputTrackerMaster(
     }
   }
 
+  /** Unregister all map output information of the given shuffle. */
+  def unregisterAllMapOutput(shuffleId: Int) {
+    shuffleStatuses.get(shuffleId) match {
+      case Some(shuffleStatus) =>
+        shuffleStatus.removeOutputsByFilter(x => true)
+        incrementEpoch()
+      case None =>
+        throw new SparkException(
+          s"unregisterAllMapOutput called for nonexistent shuffle ID $shuffleId.")
+    }
+  }
+
   /** Unregister shuffle data */
   def unregisterShuffle(shuffleId: Int) {
     shuffleStatuses.remove(shuffleId).foreach { shuffleStatus =>

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
index e4587c9..904d9c0 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
@@ -23,11 +23,21 @@ import org.apache.spark.{Partition, TaskContext}
 
 /**
  * An RDD that applies the provided function to every partition of the parent RDD.
+ *
+ * @param prev the parent RDD.
+ * @param f The function used to map a tuple of (TaskContext, partition index, input iterator) to
+ *          an output iterator.
+ * @param preservesPartitioning Whether the input function preserves the partitioner, which should
+ *                              be `false` unless `prev` is a pair RDD and the input function
+ *                              doesn't modify the keys.
+ * @param isFromBarrier Indicates whether this RDD is transformed from an RDDBarrier; a stage
+ *                      containing at least one RDDBarrier shall be turned into a barrier stage.
  */
 private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
     var prev: RDD[T],
     f: (TaskContext, Int, Iterator[T]) => Iterator[U],  // (TaskContext, partition index, iterator)
-    preservesPartitioning: Boolean = false)
+    preservesPartitioning: Boolean = false,
+    isFromBarrier: Boolean = false)
   extends RDD[U](prev) {
 
   override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None
@@ -41,4 +51,7 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
     super.clearDependencies()
     prev = null
   }
+
+  @transient protected lazy override val isBarrier_ : Boolean =
+    isFromBarrier || dependencies.exists(_.rdd.isBarrier())
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 0574abd..cbc1143 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -33,7 +33,7 @@ import org.apache.hadoop.mapred.TextOutputFormat
 
 import org.apache.spark._
 import org.apache.spark.Partitioner._
-import org.apache.spark.annotation.{DeveloperApi, Since}
+import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.internal.Logging
 import org.apache.spark.partial.BoundedDouble
@@ -1647,6 +1647,14 @@ abstract class RDD[T: ClassTag](
     }
   }
 
+  /**
+   * :: Experimental ::
+   * Indicates that Spark must launch the tasks together for the current stage.
+   */
+  @Experimental
+  @Since("2.4.0")
+  def barrier(): RDDBarrier[T] = withScope(new RDDBarrier[T](this))
+
   // =======================================================================
   // Other internal methods and fields
   // =======================================================================
@@ -1839,6 +1847,23 @@ abstract class RDD[T: ClassTag](
   def toJavaRDD() : JavaRDD[T] = {
     new JavaRDD(this)(elementClassTag)
   }
+
+  /**
+   * Whether the RDD is in a barrier stage. Spark must launch all the tasks at the same time for a
+   * barrier stage.
+   *
+ * An RDD is in a barrier stage if at least one of its parent RDD(s), or itself, is mapped from
+   * an [[RDDBarrier]]. This function always returns false for a [[ShuffledRDD]], since a
+   * [[ShuffledRDD]] indicates start of a new stage.
+   *
+ * A [[MapPartitionsRDD]] can be transformed from an [[RDDBarrier]]; in that case the
+   * [[MapPartitionsRDD]] shall be marked as barrier.
+   */
+  private[spark] def isBarrier(): Boolean = isBarrier_
+
+  // For performance concerns, cache the value to avoid repeatedly computing `isBarrier()` on a
+  // long RDD chain.
+  @transient protected lazy val isBarrier_ : Boolean = dependencies.exists(_.rdd.isBarrier())
 }
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala
new file mode 100644
index 0000000..85565d1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.BarrierTaskContext
+import org.apache.spark.TaskContext
+import org.apache.spark.annotation.{Experimental, Since}
+
+/** Represents an RDD barrier, which forces Spark to launch tasks of this stage together. */
+class RDDBarrier[T: ClassTag](rdd: RDD[T]) {
+
+  /**
+   * :: Experimental ::
+   * Maps partitions together with a provided BarrierTaskContext.
+   *
+   * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
+   * should be `false` unless `rdd` is a pair RDD and the input function doesn't modify the keys.
+   */
+  @Experimental
+  @Since("2.4.0")
+  def mapPartitions[S: ClassTag](
+      f: (Iterator[T], BarrierTaskContext) => Iterator[S],
+      preservesPartitioning: Boolean = false): RDD[S] = rdd.withScope {
+    val cleanedF = rdd.sparkContext.clean(f)
+    new MapPartitionsRDD(
+      rdd,
+      (context: TaskContext, index: Int, iter: Iterator[T]) =>
+        cleanedF(iter, context.asInstanceOf[BarrierTaskContext]),
+      preservesPartitioning,
+      isFromBarrier = true
+    )
+  }
+
+  /** TODO extra conf (e.g. timeout) */
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index 26eaa9a..e8f9b27 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -110,4 +110,6 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](
     super.clearDependencies()
     prev = null
   }
+
+  private[spark] override def isBarrier(): Boolean = false
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala
index 949e88f..6e4d062 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala
@@ -60,4 +60,10 @@ private[spark] class ActiveJob(
   val finished = Array.fill[Boolean](numPartitions)(false)
 
   var numFinished = 0
+
+  /** Resets the status of all partitions in this stage so they are marked as not finished. */
+  def resetAllPartitions(): Unit = {
+    (0 until numPartitions).foreach(finished.update(_, false))
+    numFinished = 0
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index f74425d..003d64f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1062,7 +1062,7 @@ class DAGScheduler(
             stage.pendingPartitions += id
             new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber,
               taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
-              Option(sc.applicationId), sc.applicationAttemptId)
+              Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier())
           }
 
         case stage: ResultStage =>
@@ -1072,7 +1072,8 @@ class DAGScheduler(
             val locs = taskIdToLocations(id)
             new ResultTask(stage.id, stage.latestInfo.attemptNumber,
               taskBinary, part, locs, id, properties, serializedTaskMetrics,
-              Option(jobId), Option(sc.applicationId), sc.applicationAttemptId)
+              Option(jobId), Option(sc.applicationId), sc.applicationAttemptId,
+              stage.rdd.isBarrier())
           }
       }
     } catch {
@@ -1311,17 +1312,6 @@ class DAGScheduler(
             }
         }
 
-      case Resubmitted =>
-        logInfo("Resubmitted " + task + ", so marking it as still running")
-        stage match {
-          case sms: ShuffleMapStage =>
-            sms.pendingPartitions += task.partitionId
-
-          case _ =>
-            assert(false, "TaskSetManagers should only send Resubmitted task statuses for " +
-              "tasks in ShuffleMapStages.")
-        }
-
       case FetchFailed(bmAddress, shuffleId, mapId, _, failureMessage) =>
         val failedStage = stageIdToStage(task.stageId)
         val mapStage = shuffleIdToMapStage(shuffleId)
@@ -1331,9 +1321,9 @@ class DAGScheduler(
             s" ${task.stageAttemptId} and there is a more recent attempt for that stage " +
             s"(attempt ${failedStage.latestInfo.attemptNumber}) running")
         } else {
-          failedStage.fetchFailedAttemptIds.add(task.stageAttemptId)
+          failedStage.failedAttemptIds.add(task.stageAttemptId)
           val shouldAbortStage =
-            failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts ||
+            failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts ||
             disallowStageRetryForTest
 
           // It is likely that we receive multiple FetchFailed for a single stage (because we have
@@ -1349,6 +1339,29 @@ class DAGScheduler(
               s"longer running")
           }
 
+          if (mapStage.rdd.isBarrier()) {
+            // Mark all the map as broken in the map stage, to ensure retry all the tasks on
+            // resubmitted stage attempt.
+            mapOutputTracker.unregisterAllMapOutput(shuffleId)
+          } else if (mapId != -1) {
+            // Mark the map whose fetch failed as broken in the map stage
+            mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
+          }
+
+          if (failedStage.rdd.isBarrier()) {
+            failedStage match {
+              case failedMapStage: ShuffleMapStage =>
+                // Mark all the map as broken in the map stage, to ensure retry all the tasks on
+                // resubmitted stage attempt.
+                mapOutputTracker.unregisterAllMapOutput(failedMapStage.shuffleDep.shuffleId)
+
+              case failedResultStage: ResultStage =>
+                // Mark all the partitions of the result stage to be not finished, to ensure retry
+                // all the tasks on resubmitted stage attempt.
+                failedResultStage.activeJob.map(_.resetAllPartitions())
+            }
+          }
+
           if (shouldAbortStage) {
             val abortMessage = if (disallowStageRetryForTest) {
               "Fetch failure will not retry stage due to testing config"
@@ -1375,7 +1388,7 @@ class DAGScheduler(
               // simpler while not producing an overwhelming number of scheduler events.
               logInfo(
                 s"Resubmitting $mapStage (${mapStage.name}) and " +
-                s"$failedStage (${failedStage.name}) due to fetch failure"
+                  s"$failedStage (${failedStage.name}) due to fetch failure"
               )
               messageScheduler.schedule(
                 new Runnable {
@@ -1386,10 +1399,6 @@ class DAGScheduler(
               )
             }
           }
-          // Mark the map whose fetch failed as broken in the map stage
-          if (mapId != -1) {
-            mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
-          }
 
           // TODO: mark the executor as failed only if there were lots of fetch failures on it
           if (bmAddress != null) {
@@ -1411,6 +1420,76 @@ class DAGScheduler(
           }
         }
 
+      case failure: TaskFailedReason if task.isBarrier =>
+        // Also handle the task failed reasons here.
+        failure match {
+          case Resubmitted =>
+            handleResubmittedFailure(task, stage)
+
+          case _ => // Do nothing.
+        }
+
+        // Always fail the current stage and retry all the tasks when a barrier task fails.
+        val failedStage = stageIdToStage(task.stageId)
+        logInfo(s"Marking $failedStage (${failedStage.name}) as failed due to a barrier task " +
+          "failed.")
+        val message = s"Stage failed because barrier task $task finished unsuccessfully. " +
+          failure.toErrorString
+        try {
+          // cancelTasks will fail if a SchedulerBackend does not implement killTask
+          taskScheduler.cancelTasks(stageId, interruptThread = false)
+        } catch {
+          case e: UnsupportedOperationException =>
+            // Cannot continue with barrier stage if failed to cancel zombie barrier tasks.
+            // TODO SPARK-24877 leave the zombie tasks and ignore their completion events.
+            logWarning(s"Could not cancel tasks for stage $stageId", e)
+            abortStage(failedStage, "Could not cancel zombie barrier tasks for stage " +
+              s"$failedStage (${failedStage.name})", Some(e))
+        }
+        markStageAsFinished(failedStage, Some(message))
+
+        failedStage.failedAttemptIds.add(task.stageAttemptId)
+        // TODO Refactor the failure handling logic to combine similar code with that of
+        // FetchFailed.
+        val shouldAbortStage =
+          failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts ||
+            disallowStageRetryForTest
+
+        if (shouldAbortStage) {
+          val abortMessage = if (disallowStageRetryForTest) {
+            "Barrier stage will not retry stage due to testing config"
+          } else {
+            s"""$failedStage (${failedStage.name})
+               |has failed the maximum allowable number of
+               |times: $maxConsecutiveStageAttempts.
+               |Most recent failure reason: $message""".stripMargin.replaceAll("\n", " ")
+          }
+          abortStage(failedStage, abortMessage, None)
+        } else {
+          failedStage match {
+            case failedMapStage: ShuffleMapStage =>
+              // Mark all the map as broken in the map stage, to ensure retry all the tasks on
+              // resubmitted stage attempt.
+              mapOutputTracker.unregisterAllMapOutput(failedMapStage.shuffleDep.shuffleId)
+
+            case failedResultStage: ResultStage =>
+              // Mark all the partitions of the result stage to be not finished, to ensure retry
+              // all the tasks on resubmitted stage attempt.
+              failedResultStage.activeJob.map(_.resetAllPartitions())
+          }
+
+          // update failedStages and make sure a ResubmitFailedStages event is enqueued
+          failedStages += failedStage
+          logInfo(s"Resubmitting $failedStage (${failedStage.name}) due to barrier stage " +
+            "failure.")
+          messageScheduler.schedule(new Runnable {
+            override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+          }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
+        }
+
+      case Resubmitted =>
+        handleResubmittedFailure(task, stage)
+
       case _: TaskCommitDenied =>
         // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits
 
@@ -1426,6 +1505,18 @@ class DAGScheduler(
     }
   }
 
+  private def handleResubmittedFailure(task: Task[_], stage: Stage): Unit = {
+    logInfo(s"Resubmitted $task, so marking it as still running.")
+    stage match {
+      case sms: ShuffleMapStage =>
+        sms.pendingPartitions += task.partitionId
+
+      case _ =>
+        throw new SparkException("TaskSetManagers should only send Resubmitted task " +
+          "statuses for tasks in ShuffleMapStages.")
+    }
+  }
+
   private[scheduler] def markMapStageJobsAsFinished(shuffleStage: ShuffleMapStage): Unit = {
     // Mark any map-stage jobs waiting on this stage as finished
     if (shuffleStage.isAvailable && shuffleStage.mapStageJobs.nonEmpty) {

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index e36c759..aafeae0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -48,7 +48,9 @@ import org.apache.spark.rdd.RDD
  * @param jobId id of the job this task belongs to
  * @param appId id of the app this task belongs to
  * @param appAttemptId attempt id of the app this task belongs to
-  */
+ * @param isBarrier whether this task belongs to a barrier stage. Spark must launch all the tasks
+ *                  at the same time for a barrier stage.
+ */
 private[spark] class ResultTask[T, U](
     stageId: Int,
     stageAttemptId: Int,
@@ -60,9 +62,10 @@ private[spark] class ResultTask[T, U](
     serializedTaskMetrics: Array[Byte],
     jobId: Option[Int] = None,
     appId: Option[String] = None,
-    appAttemptId: Option[String] = None)
+    appAttemptId: Option[String] = None,
+    isBarrier: Boolean = false)
   extends Task[U](stageId, stageAttemptId, partition.index, localProperties, serializedTaskMetrics,
-    jobId, appId, appAttemptId)
+    jobId, appId, appAttemptId, isBarrier)
   with Serializable {
 
   @transient private[this] val preferredLocs: Seq[TaskLocation] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 7a25c47..f2cd65f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -49,6 +49,8 @@ import org.apache.spark.shuffle.ShuffleWriter
  * @param jobId id of the job this task belongs to
  * @param appId id of the app this task belongs to
  * @param appAttemptId attempt id of the app this task belongs to
+ * @param isBarrier whether this task belongs to a barrier stage. Spark must launch all the tasks
+ *                  at the same time for a barrier stage.
  */
 private[spark] class ShuffleMapTask(
     stageId: Int,
@@ -60,9 +62,10 @@ private[spark] class ShuffleMapTask(
     serializedTaskMetrics: Array[Byte],
     jobId: Option[Int] = None,
     appId: Option[String] = None,
-    appAttemptId: Option[String] = None)
+    appAttemptId: Option[String] = None,
+    isBarrier: Boolean = false)
   extends Task[MapStatus](stageId, stageAttemptId, partition.index, localProperties,
-    serializedTaskMetrics, jobId, appId, appAttemptId)
+    serializedTaskMetrics, jobId, appId, appAttemptId, isBarrier)
   with Logging {
 
   /** A constructor used only in test suites. This does not require passing in an RDD. */

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 290fd07..26cca33 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -82,15 +82,15 @@ private[scheduler] abstract class Stage(
   private var _latestInfo: StageInfo = StageInfo.fromStage(this, nextAttemptId)
 
   /**
-   * Set of stage attempt IDs that have failed with a FetchFailure. We keep track of these
-   * failures in order to avoid endless retries if a stage keeps failing with a FetchFailure.
+   * Set of stage attempt IDs that have failed. We keep track of these failures in order to avoid
+   * endless retries if a stage keeps failing.
    * We keep track of each attempt ID that has failed to avoid recording duplicate failures if
    * multiple tasks from the same stage attempt fail (SPARK-5945).
    */
-  val fetchFailedAttemptIds = new HashSet[Int]
+  val failedAttemptIds = new HashSet[Int]
 
   private[scheduler] def clearFailures() : Unit = {
-    fetchFailedAttemptIds.clear()
+    failedAttemptIds.clear()
   }
 
   /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index f536fc2..89ff203 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -49,6 +49,8 @@ import org.apache.spark.util._
  * @param jobId id of the job this task belongs to
  * @param appId id of the app this task belongs to
  * @param appAttemptId attempt id of the app this task belongs to
+ * @param isBarrier whether this task belongs to a barrier stage. Spark must launch all the tasks
+ *                  at the same time for a barrier stage.
  */
 private[spark] abstract class Task[T](
     val stageId: Int,
@@ -60,7 +62,8 @@ private[spark] abstract class Task[T](
       SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array(),
     val jobId: Option[Int] = None,
     val appId: Option[String] = None,
-    val appAttemptId: Option[String] = None) extends Serializable {
+    val appAttemptId: Option[String] = None,
+    val isBarrier: Boolean = false) extends Serializable {
 
   @transient lazy val metrics: TaskMetrics =
     SparkEnv.get.closureSerializer.newInstance().deserialize(ByteBuffer.wrap(serializedTaskMetrics))
@@ -77,16 +80,32 @@ private[spark] abstract class Task[T](
       attemptNumber: Int,
       metricsSystem: MetricsSystem): T = {
     SparkEnv.get.blockManager.registerTask(taskAttemptId)
-    context = new TaskContextImpl(
-      stageId,
-      stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal
-      partitionId,
-      taskAttemptId,
-      attemptNumber,
-      taskMemoryManager,
-      localProperties,
-      metricsSystem,
-      metrics)
+    // TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether
+    // the stage is barrier.
+    context = if (isBarrier) {
+      new BarrierTaskContextImpl(
+        stageId,
+        stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal
+        partitionId,
+        taskAttemptId,
+        attemptNumber,
+        taskMemoryManager,
+        localProperties,
+        metricsSystem,
+        metrics)
+    } else {
+      new TaskContextImpl(
+        stageId,
+        stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal
+        partitionId,
+        taskAttemptId,
+        attemptNumber,
+        taskMemoryManager,
+        localProperties,
+        metricsSystem,
+        metrics)
+    }
+
     TaskContext.setTaskContext(context)
     taskThread = Thread.currentThread()
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
index c98b871..bb4a444 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
@@ -50,6 +50,7 @@ private[spark] class TaskDescription(
     val executorId: String,
     val name: String,
     val index: Int,    // Index within this task's TaskSet
+    val partitionId: Int,
     val addedFiles: Map[String, Long],
     val addedJars: Map[String, Long],
     val properties: Properties,
@@ -76,6 +77,7 @@ private[spark] object TaskDescription {
     dataOut.writeUTF(taskDescription.executorId)
     dataOut.writeUTF(taskDescription.name)
     dataOut.writeInt(taskDescription.index)
+    dataOut.writeInt(taskDescription.partitionId)
 
     // Write files.
     serializeStringLongMap(taskDescription.addedFiles, dataOut)
@@ -117,6 +119,7 @@ private[spark] object TaskDescription {
     val executorId = dataIn.readUTF()
     val name = dataIn.readUTF()
     val index = dataIn.readInt()
+    val partitionId = dataIn.readInt()
 
     // Read files.
     val taskFiles = deserializeStringLongMap(dataIn)
@@ -138,7 +141,7 @@ private[spark] object TaskDescription {
     // Create a sub-buffer for the serialized task into its own buffer (to be deserialized later).
     val serializedTask = byteBuffer.slice()
 
-    new TaskDescription(taskId, attemptNumber, executorId, name, index, taskFiles, taskJars,
-      properties, serializedTask)
+    new TaskDescription(taskId, attemptNumber, executorId, name, index, partitionId, taskFiles,
+      taskJars, properties, serializedTask)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 56c0bf6..587ed4b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -274,7 +274,8 @@ private[spark] class TaskSchedulerImpl(
       maxLocality: TaskLocality,
       shuffledOffers: Seq[WorkerOffer],
       availableCpus: Array[Int],
-      tasks: IndexedSeq[ArrayBuffer[TaskDescription]]) : Boolean = {
+      tasks: IndexedSeq[ArrayBuffer[TaskDescription]],
+      addressesWithDescs: ArrayBuffer[(String, TaskDescription)]) : Boolean = {
     var launchedTask = false
     // nodes and executors that are blacklisted for the entire application have already been
     // filtered out by this point
@@ -291,6 +292,11 @@ private[spark] class TaskSchedulerImpl(
             executorIdToRunningTaskIds(execId).add(tid)
             availableCpus(i) -= CPUS_PER_TASK
             assert(availableCpus(i) >= 0)
+            // Only update hosts for a barrier task.
+            if (taskSet.isBarrier) {
+              // The executor address is expected to be non empty.
+              addressesWithDescs += (shuffledOffers(i).address.get -> task)
+            }
             launchedTask = true
           }
         } catch {
@@ -346,6 +352,7 @@ private[spark] class TaskSchedulerImpl(
     // Build a list of tasks to assign to each worker.
     val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores / CPUS_PER_TASK))
     val availableCpus = shuffledOffers.map(o => o.cores).toArray
+    val availableSlots = shuffledOffers.map(o => o.cores / CPUS_PER_TASK).sum
     val sortedTaskSets = rootPool.getSortedTaskSetQueue
     for (taskSet <- sortedTaskSets) {
       logDebug("parentName: %s, name: %s, runningTasks: %s".format(
@@ -359,20 +366,55 @@ private[spark] class TaskSchedulerImpl(
     // of locality levels so that it gets a chance to launch local tasks on all of them.
     // NOTE: the preferredLocality order: PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY
     for (taskSet <- sortedTaskSets) {
-      var launchedAnyTask = false
-      var launchedTaskAtCurrentMaxLocality = false
-      for (currentMaxLocality <- taskSet.myLocalityLevels) {
-        do {
-          launchedTaskAtCurrentMaxLocality = resourceOfferSingleTaskSet(
-            taskSet, currentMaxLocality, shuffledOffers, availableCpus, tasks)
-          launchedAnyTask |= launchedTaskAtCurrentMaxLocality
-        } while (launchedTaskAtCurrentMaxLocality)
-      }
-      if (!launchedAnyTask) {
-        taskSet.abortIfCompletelyBlacklisted(hostToExecutors)
+      // Skip the barrier taskSet if the available slots are less than the number of pending tasks.
+      if (taskSet.isBarrier && availableSlots < taskSet.numTasks) {
+        // Skip the launch process.
+        // TODO SPARK-24819 If the job requires more slots than available (both busy and free
+        // slots), fail the job on submit.
+        logInfo(s"Skip current round of resource offers for barrier stage ${taskSet.stageId} " +
+          s"because the barrier taskSet requires ${taskSet.numTasks} slots, while the total " +
+          s"number of available slots is $availableSlots.")
+      } else {
+        var launchedAnyTask = false
+        // Record all the executor IDs assigned barrier tasks on.
+        val addressesWithDescs = ArrayBuffer[(String, TaskDescription)]()
+        for (currentMaxLocality <- taskSet.myLocalityLevels) {
+          var launchedTaskAtCurrentMaxLocality = false
+          do {
+            launchedTaskAtCurrentMaxLocality = resourceOfferSingleTaskSet(taskSet,
+              currentMaxLocality, shuffledOffers, availableCpus, tasks, addressesWithDescs)
+            launchedAnyTask |= launchedTaskAtCurrentMaxLocality
+          } while (launchedTaskAtCurrentMaxLocality)
+        }
+        if (!launchedAnyTask) {
+          taskSet.abortIfCompletelyBlacklisted(hostToExecutors)
+        }
+        if (launchedAnyTask && taskSet.isBarrier) {
+          // Check whether the barrier tasks are partially launched.
+          // TODO SPARK-24818 handle the assert failure case (that can happen when some locality
+          // requirements are not fulfilled, and we should revert the launched tasks).
+          require(addressesWithDescs.size == taskSet.numTasks,
+            s"Skip current round of resource offers for barrier stage ${taskSet.stageId} " +
+              s"because only ${addressesWithDescs.size} out of a total number of " +
+              s"${taskSet.numTasks} tasks got resource offers. The resource offers may have " +
+              "been blacklisted or cannot fulfill task locality requirements.")
+
+          // Update the taskInfos into all the barrier task properties.
+          val addressesStr = addressesWithDescs
+            // Addresses ordered by partitionId
+            .sortBy(_._2.partitionId)
+            .map(_._1)
+            .mkString(",")
+          addressesWithDescs.foreach(_._2.properties.setProperty("addresses", addressesStr))
+
+          logInfo(s"Successfully scheduled all the ${addressesWithDescs.size} tasks for barrier " +
+            s"stage ${taskSet.stageId}.")
+        }
       }
     }
 
+    // TODO SPARK-24823 Cancel a job that contains barrier stage(s) if the barrier tasks don't get
+    // launched within a configured time.
     if (tasks.size > 0) {
       hasLaunchedTask = true
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index defed1e..0b21256 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -123,6 +123,10 @@ private[spark] class TaskSetManager(
   // TODO: We should kill any running task attempts when the task set manager becomes a zombie.
   private[scheduler] var isZombie = false
 
+  // Whether the taskSet runs tasks from a barrier stage. Spark must launch all the tasks at the
+  // same time for a barrier stage.
+  private[scheduler] def isBarrier = taskSet.tasks.nonEmpty && taskSet.tasks(0).isBarrier
+
   // Set of pending tasks for each executor. These collections are actually
   // treated as stacks, in which new tasks are added to the end of the
   // ArrayBuffer and removed from the end. This makes it faster to detect
@@ -512,6 +516,7 @@ private[spark] class TaskSetManager(
           execId,
           taskName,
           index,
+          task.partitionId,
           addedFiles,
           addedJars,
           task.localProperties,
@@ -979,8 +984,8 @@ private[spark] class TaskSetManager(
    */
   override def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean = {
     // Can't speculate if we only have one task, and no need to speculate if the task set is a
-    // zombie.
-    if (isZombie || numTasks == 1) {
+    // zombie or is from a barrier stage.
+    if (isZombie || isBarrier || numTasks == 1) {
       return false
     }
     var foundTasks = false

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala
index 810b36c..6ec7491 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala
@@ -21,4 +21,10 @@ package org.apache.spark.scheduler
  * Represents free resources available on an executor.
  */
 private[spark]
-case class WorkerOffer(executorId: String, host: String, cores: Int)
+case class WorkerOffer(
+    executorId: String,
+    host: String,
+    cores: Int,
+    // `address` is an optional hostPort string; it provides more useful information than `host`
+    // when multiple executors are launched on the same host.
+    address: Option[String] = None)

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 9b90e30..375aeb0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -242,7 +242,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
         val activeExecutors = executorDataMap.filterKeys(executorIsAlive)
         val workOffers = activeExecutors.map {
           case (id, executorData) =>
-            new WorkerOffer(id, executorData.executorHost, executorData.freeCores)
+            new WorkerOffer(id, executorData.executorHost, executorData.freeCores,
+              Some(executorData.executorAddress.hostPort))
         }.toIndexedSeq
         scheduler.resourceOffers(workOffers)
       }
@@ -267,7 +268,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
         if (executorIsAlive(executorId)) {
           val executorData = executorDataMap(executorId)
           val workOffers = IndexedSeq(
-            new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores))
+            new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores,
+              Some(executorData.executorAddress.hostPort)))
           scheduler.resourceOffers(workOffers)
         } else {
           Seq.empty

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala
index 4c614c5..cf8b0ff 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala
@@ -81,7 +81,8 @@ private[spark] class LocalEndpoint(
   }
 
   def reviveOffers() {
-    val offers = IndexedSeq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores))
+    val offers = IndexedSeq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores,
+      Some(rpcEnv.address.hostPort)))
     for (task <- scheduler.resourceOffers(offers).flatten) {
       freeCores -= scheduler.CPUS_PER_TASK
       executor.launchTask(executorBackend, task)

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
index ce9f2be..e5f31a0 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -627,6 +627,48 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
     assert(exc.getCause() != null)
     stream.close()
   }
+
+  test("support barrier execution mode under local mode") {
+    val conf = new SparkConf().setAppName("test").setMaster("local[2]")
+    sc = new SparkContext(conf)
+    val rdd = sc.makeRDD(Seq(1, 2, 3, 4), 2)
+    val rdd2 = rdd.barrier().mapPartitions { (it, context) =>
+      // If we don't get the expected taskInfos, the job shall abort due to stage failure.
+      if (context.getTaskInfos().length != 2) {
+        throw new SparkException("Expected taksInfos length is 2, actual length is " +
+          s"${context.getTaskInfos().length}.")
+      }
+      context.barrier()
+      it
+    }
+    rdd2.collect()
+
+    eventually(timeout(10.seconds)) {
+      assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0)
+    }
+  }
+
+  test("support barrier execution mode under local-cluster mode") {
+    val conf = new SparkConf()
+      .setMaster("local-cluster[3, 1, 1024]")
+      .setAppName("test-cluster")
+    sc = new SparkContext(conf)
+    val rdd = sc.makeRDD(Seq(1, 2, 3, 4), 2)
+    val rdd2 = rdd.barrier().mapPartitions { (it, context) =>
+      // If we don't get the expected taskInfos, the job shall abort due to stage failure.
+      if (context.getTaskInfos().length != 2) {
+        throw new SparkException("Expected taksInfos length is 2, actual length is " +
+          s"${context.getTaskInfos().length}.")
+      }
+      context.barrier()
+      it
+    }
+    rdd2.collect()
+
+    eventually(timeout(10.seconds)) {
+      assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0)
+    }
+  }
 }
 
 object SparkContextSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index 1a7bebe..77a7668 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -275,6 +275,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
       executorId = "",
       name = "",
       index = 0,
+      partitionId = 0,
       addedFiles = Map[String, Long](),
       addedJars = Map[String, Long](),
       properties = new Properties,

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala
new file mode 100644
index 0000000..39d4618
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.{SharedSparkContext, SparkFunSuite}
+
+class RDDBarrierSuite extends SparkFunSuite with SharedSparkContext {
+
+  test("create an RDDBarrier") {
+    val rdd = sc.parallelize(1 to 10, 4)
+    assert(rdd.isBarrier() === false)
+
+    val rdd2 = rdd.barrier().mapPartitions((iter, context) => iter)
+    assert(rdd2.isBarrier() === true)
+  }
+
+  test("create an RDDBarrier in the middle of a chain of RDDs") {
+    val rdd = sc.parallelize(1 to 10, 4).map(x => x * 2)
+    val rdd2 = rdd.barrier().mapPartitions((iter, context) => iter).map(x => (x, x + 1))
+    assert(rdd2.isBarrier() === true)
+  }
+
+  test("RDDBarrier with shuffle") {
+    val rdd = sc.parallelize(1 to 10, 4)
+    val rdd2 = rdd.barrier().mapPartitions((iter, context) => iter).repartition(2)
+    assert(rdd2.isBarrier() === false)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 2987170..b3db5e2 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -1055,6 +1055,64 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     assert(sparkListener.failedStages.size == 1)
   }
 
+  test("Retry all the tasks on a resubmitted attempt of a barrier stage caused by FetchFailure") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions((it, context) => it)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
+    submit(reduceRdd, Array(0, 1))
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", reduceRdd.partitions.length)),
+      (Success, makeMapStatus("hostB", reduceRdd.partitions.length))))
+    assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty))
+
+    // The first result task fails, with a fetch failure for the output from the first mapper.
+    runEvent(makeCompletionEvent(
+      taskSets(1).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+      null))
+    assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq(0, 1)))
+
+    scheduler.resubmitFailedStages()
+    // Complete the map stage.
+    completeShuffleMapStageSuccessfully(0, 1, numShufflePartitions = 2)
+
+    // Complete the result stage.
+    completeNextResultStageWithSuccess(1, 1)
+
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assertDataStructuresEmpty()
+  }
+
+  test("Retry all the tasks on a resubmitted attempt of a barrier stage caused by TaskKilled") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions((it, context) => it)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
+    submit(reduceRdd, Array(0, 1))
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", reduceRdd.partitions.length))))
+    assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq(1)))
+
+    // The second map task fails with TaskKilled.
+    runEvent(makeCompletionEvent(
+      taskSets(0).tasks(1),
+      TaskKilled("test"),
+      null))
+    assert(sparkListener.failedStages === Seq(0))
+    assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq(0, 1)))
+
+    scheduler.resubmitFailedStages()
+    // Complete the map stage.
+    completeShuffleMapStageSuccessfully(0, 1, numShufflePartitions = 2)
+
+    // Complete the result stage.
+    completeNextResultStageWithSuccess(1, 0)
+
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assertDataStructuresEmpty()
+  }
+
   /**
    * This tests the case where another FetchFailed comes in while the map stage is getting
    * re-run.
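The two new tests above encode the same invariant from two different failure paths: after the failure, the map output tracker reports every partition of the barrier shuffle map stage as missing, so the stage is re-run in full rather than only for the failed partition. A hedged sketch of that check follows; needsFullRerun is a hypothetical helper for illustration, not part of this patch.

    // Sketch of the invariant asserted above: after one failure touching a barrier
    // stage, the missing-partition set grows from empty to all partitions.
    object BarrierRetryInvariant {
      // `missing` stands in for mapOutputTracker.findMissingPartitions(shuffleId).
      def needsFullRerun(missing: Option[Seq[Int]], numPartitions: Int): Boolean =
        missing.exists(_.toSet == (0 until numPartitions).toSet)

      def main(args: Array[String]): Unit = {
        // Mirrors the assertions for the two-partition barrier map stage above.
        assert(!needsFullRerun(Some(Seq.empty), numPartitions = 2)) // before the failure
        assert(needsFullRerun(Some(Seq(0, 1)), numPartitions = 2))  // after FetchFailed / TaskKilled
      }
    }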

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
index 109d4a0..b29d32f 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -27,8 +27,10 @@ class FakeTask(
     partitionId: Int,
     prefLocs: Seq[TaskLocation] = Nil,
     serializedTaskMetrics: Array[Byte] =
-      SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array())
-  extends Task[Int](stageId, 0, partitionId, new Properties, serializedTaskMetrics) {
+      SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array(),
+    isBarrier: Boolean = false)
+  extends Task[Int](stageId, 0, partitionId, new Properties, serializedTaskMetrics,
+    isBarrier = isBarrier) {
 
   override def runTask(context: TaskContext): Int = 0
   override def preferredLocations: Seq[TaskLocation] = prefLocs
@@ -74,4 +76,22 @@ object FakeTask {
     }
     new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null)
   }
+
+  def createBarrierTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
+    createBarrierTaskSet(numTasks, stageId = 0, stageAttemptId = 0, prefLocs: _*)
+  }
+
+  def createBarrierTaskSet(
+      numTasks: Int,
+      stageId: Int,
+      stageAttemptId: Int,
+      prefLocs: Seq[TaskLocation]*): TaskSet = {
+    if (prefLocs.size != 0 && prefLocs.size != numTasks) {
+      throw new IllegalArgumentException("Wrong number of task locations")
+    }
+    val tasks = Array.tabulate[Task[_]](numTasks) { i =>
+      new FakeTask(stageId, i, if (prefLocs.size != 0) prefLocs(i) else Nil, isBarrier = true)
+    }
+    new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala
index 97487ce..ba62eec 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala
@@ -62,6 +62,7 @@ class TaskDescriptionSuite extends SparkFunSuite {
       executorId = "testExecutor",
       name = "task for test",
       index = 19,
+      partitionId = 1,
       originalFiles,
       originalJars,
       originalProperties,
@@ -77,6 +78,7 @@ class TaskDescriptionSuite extends SparkFunSuite {
     assert(decodedTaskDescription.executorId === originalTaskDescription.executorId)
     assert(decodedTaskDescription.name === originalTaskDescription.name)
     assert(decodedTaskDescription.index === originalTaskDescription.index)
+    assert(decodedTaskDescription.partitionId === originalTaskDescription.partitionId)
     assert(decodedTaskDescription.addedFiles.equals(originalFiles))
     assert(decodedTaskDescription.addedJars.equals(originalJars))
     assert(decodedTaskDescription.properties.equals(originalTaskDescription.properties))

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index 33f2ea1..624384a 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -1021,4 +1021,38 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
       verify(blacklist).updateBlacklistForSuccessfulTaskSet(meq(0), meq(stageAttempt), anyObject())
     }
   }
+
+  test("don't schedule for a barrier taskSet if available slots are less than pending tasks") {
+    val taskCpus = 2
+    val taskScheduler = setupScheduler("spark.task.cpus" -> taskCpus.toString)
+
+    val numFreeCores = 3
+    val workerOffers = IndexedSeq(
+      new WorkerOffer("executor0", "host0", numFreeCores, Some("192.168.0.101:49625")),
+      new WorkerOffer("executor1", "host1", numFreeCores, Some("192.168.0.101:49627")))
+    val attempt1 = FakeTask.createBarrierTaskSet(3)
+
+    // Submit attempt 1 and offer some resources. Since the available slots are fewer than the
+    // number of pending tasks, no barrier task should be scheduled on this resource offer.
+    taskScheduler.submitTasks(attempt1)
+    val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(0 === taskDescriptions.length)
+  }
+
+  test("schedule tasks for a barrier taskSet if all tasks can be launched together") {
+    val taskCpus = 2
+    val taskScheduler = setupScheduler("spark.task.cpus" -> taskCpus.toString)
+
+    val numFreeCores = 3
+    val workerOffers = IndexedSeq(
+      new WorkerOffer("executor0", "host0", numFreeCores, Some("192.168.0.101:49625")),
+      new WorkerOffer("executor1", "host1", numFreeCores, Some("192.168.0.101:49627")),
+      new WorkerOffer("executor2", "host2", numFreeCores, Some("192.168.0.101:49629")))
+    val attempt1 = FakeTask.createBarrierTaskSet(3)
+
+    // Submit attempt 1 and offer some resources; all barrier tasks get launched together.
+    taskScheduler.submitTasks(attempt1)
+    val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(3 === taskDescriptions.length)
+  }
 }
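The two barrier tests above pin down the slot arithmetic: with spark.task.cpus = 2, an offer with 3 free cores contributes one slot, so two executors provide 2 slots (fewer than the 3 pending barrier tasks, so nothing is launched), while three executors provide 3 slots and all tasks launch together. A small sketch of that check, consistent with these tests; canLaunchBarrierTaskSet is a hypothetical helper, not the scheduler's actual code.

    object BarrierSlotSketch {
      // Each offer contributes freeCores / taskCpus slots (integer division); a barrier
      // task set is launched only when the total is at least its number of tasks.
      def canLaunchBarrierTaskSet(freeCoresPerOffer: Seq[Int], taskCpus: Int, numTasks: Int): Boolean =
        freeCoresPerOffer.map(_ / taskCpus).sum >= numTasks

      def main(args: Array[String]): Unit = {
        // Mirrors the two tests above: spark.task.cpus = 2, 3 free cores per executor.
        assert(!canLaunchBarrierTaskSet(Seq(3, 3), taskCpus = 2, numTasks = 3))   // 2 slots < 3 tasks
        assert(canLaunchBarrierTaskSet(Seq(3, 3, 3), taskCpus = 2, numTasks = 3)) // 3 slots >= 3 tasks
      }
    }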

http://git-wip-us.apache.org/repos/asf/spark/blob/e3486e1b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala
----------------------------------------------------------------------
diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala
index 2d2f90c..31f8431 100644
--- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala
+++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala
@@ -253,6 +253,7 @@ class MesosFineGrainedSchedulerBackendSuite
       executorId = "s1",
       name = "n1",
       index = 0,
+      partitionId = 0,
       addedFiles = mutable.Map.empty[String, Long],
       addedJars = mutable.Map.empty[String, Long],
       properties = new Properties(),
@@ -361,6 +362,7 @@ class MesosFineGrainedSchedulerBackendSuite
       executorId = "s1",
       name = "n1",
       index = 0,
+      partitionId = 0,
       addedFiles = mutable.Map.empty[String, Long],
       addedJars = mutable.Map.empty[String, Long],
       properties = new Properties(),

