Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2018/12/15 00:33:56 UTC

[GitHub] vanzin closed pull request #19267: [WIP][SPARK-20628][CORE] Blacklist nodes when they transition to DECOMMISSIONING state in YARN

URL: https://github.com/apache/spark/pull/19267
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/core/src/main/scala/org/apache/spark/HostState.scala b/core/src/main/scala/org/apache/spark/HostState.scala
new file mode 100644
index 0000000000000..17b374c3fac26
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/HostState.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.hadoop.yarn.api.records.NodeState
+
+private[spark] object HostState extends Enumeration {
+
+  type HostState = Value
+
+  val New, Running, Unhealthy, Decommissioning, Decommissioned, Lost, Rebooted = Value
+
+  def fromYarnState(state: String): Option[HostState] = {
+    HostState.values.find(_.toString.toUpperCase == state)
+  }
+
+  def toYarnState(state: HostState): Option[String] = {
+    NodeState.values.find(_.name == state.toString.toUpperCase).map(_.name)
+  }
+}
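
The HostState values deliberately mirror YARN's NodeState names, so conversion in either direction is just a case-insensitive name lookup. Below is a minimal, dependency-free sketch of that round trip; plain strings stand in for org.apache.hadoop.yarn.api.records.NodeState, which the real code above uses, and the object name is illustrative only:

    object HostStateSketch extends Enumeration {
      type HostStateSketch = Value
      val New, Running, Unhealthy, Decommissioning, Decommissioned, Lost, Rebooted = Value

      // YARN reports node states as upper-case names such as "DECOMMISSIONING".
      def fromYarnState(state: String): Option[HostStateSketch] =
        values.find(_.toString.equalsIgnoreCase(state))

      // The reverse mapping just upper-cases the enumeration name.
      def toYarnState(state: HostStateSketch): String = state.toString.toUpperCase
    }

    // HostStateSketch.fromYarnState("DECOMMISSIONING") == Some(HostStateSketch.Decommissioning)
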
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 9495cd2835f97..84edcff707d44 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -154,6 +154,16 @@ package object config {
     ConfigBuilder("spark.blacklist.application.fetchFailure.enabled")
       .booleanConf
       .createWithDefault(false)
+
+  private[spark] val BLACKLIST_DECOMMISSIONING_ENABLED =
+    ConfigBuilder("spark.blacklist.decommissioning.enabled")
+      .booleanConf
+      .createWithDefault(false)
+
+  private[spark] val BLACKLIST_DECOMMISSIONING_TIMEOUT_CONF =
+    ConfigBuilder("spark.blacklist.decommissioning.timeout")
+      .timeConf(TimeUnit.MILLISECONDS)
+      .createOptional
   // End blacklist confs
 
   private[spark] val UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE =
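
For context, the two new keys are set like any other blacklist configuration. A small usage sketch follows; the 30m value is only an illustration, not a default from this PR (when the timeout is unset, the tracker falls back to its own 1h default via getBlacklistDecommissioningTimeout further below):

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.blacklist.decommissioning.enabled", "true")
      // Optional: how long a DECOMMISSIONING node stays blacklisted.
      .set("spark.blacklist.decommissioning.timeout", "30m")
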
diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala
index cd8e61d6d0208..7bc3db8ce1bb9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala
@@ -61,7 +61,13 @@ private[scheduler] class BlacklistTracker (
   private val MAX_FAILURES_PER_EXEC = conf.get(config.MAX_FAILURES_PER_EXEC)
   private val MAX_FAILED_EXEC_PER_NODE = conf.get(config.MAX_FAILED_EXEC_PER_NODE)
   val BLACKLIST_TIMEOUT_MILLIS = BlacklistTracker.getBlacklistTimeout(conf)
-  private val BLACKLIST_FETCH_FAILURE_ENABLED = conf.get(config.BLACKLIST_FETCH_FAILURE_ENABLED)
+  val BLACKLIST_DECOMMISSIONING_TIMEOUT_MILLIS =
+    BlacklistTracker.getBlacklistDecommissioningTimeout(conf)
+  private val TASK_BLACKLISTING_ENABLED = BlacklistTracker.isTaskExecutionBlacklistingEnabled(conf)
+  private val DECOMMISSIONING_BLACKLISTING_ENABLED =
+    BlacklistTracker.isDecommissioningBlacklistingEnabled(conf)
+  private val BLACKLIST_FETCH_FAILURE_ENABLED =
+    BlacklistTracker.isFetchFailureBlacklistingEnabled(conf)
 
   /**
    * A map from executorId to information on task failures.  Tracks the time of each task failure,
@@ -89,13 +95,13 @@ private[scheduler] class BlacklistTracker (
    * successive blacklisted executors on one node.  Nonetheless, it will not grow too large because
    * there cannot be many blacklisted executors on one node, before we stop requesting more
    * executors on that node, and we clean up the list of blacklisted executors once an executor has
-   * been blacklisted for BLACKLIST_TIMEOUT_MILLIS.
+   * been blacklisted for its configured blacklisting timeout.
    */
   val nodeToBlacklistedExecs = new HashMap[String, HashSet[String]]()
 
   /**
-   * Un-blacklists executors and nodes that have been blacklisted for at least
-   * BLACKLIST_TIMEOUT_MILLIS
+   * Un-blacklists executors and nodes that have been blacklisted for at least their respective
+   * configured blacklisting timeouts.
    */
   def applyBlacklistTimeout(): Unit = {
     val now = clock.getTimeMillis()
@@ -118,16 +124,9 @@ private[scheduler] class BlacklistTracker (
         }
       }
       val nodesToUnblacklist = nodeIdToBlacklistExpiryTime.filter(_._2 < now).keys
-      if (nodesToUnblacklist.nonEmpty) {
-        // Un-blacklist any nodes that have been blacklisted longer than the blacklist timeout.
-        logInfo(s"Removing nodes $nodesToUnblacklist from blacklist because the blacklist " +
-          s"has timed out")
-        nodesToUnblacklist.foreach { node =>
-          nodeIdToBlacklistExpiryTime.remove(node)
-          listenerBus.post(SparkListenerNodeUnblacklisted(now, node))
-        }
-        _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet)
-      }
+          .map(node => (node, BlacklistTimedOut, Some(now)))
+      // Un-blacklist any nodes that have been blacklisted longer than the blacklist timeout.
+      removeNodesFromBlacklist(nodesToUnblacklist)
       updateNextExpiryTime()
     }
   }
@@ -190,14 +189,8 @@ private[scheduler] class BlacklistTracker (
       val expiryTimeForNewBlacklists = now + BLACKLIST_TIMEOUT_MILLIS
 
       if (conf.get(config.SHUFFLE_SERVICE_ENABLED)) {
-        if (!nodeIdToBlacklistExpiryTime.contains(host)) {
-          logInfo(s"blacklisting node $host due to fetch failure of external shuffle service")
-
-          nodeIdToBlacklistExpiryTime.put(host, expiryTimeForNewBlacklists)
-          listenerBus.post(SparkListenerNodeBlacklisted(now, host, 1))
-          _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet)
+        if (addNodeToBlacklist(host, FetchFailure(host), now)) {
           killExecutorsOnBlacklistedNode(host)
-          updateNextExpiryTime()
         }
       } else if (!executorIdToBlacklistStatus.contains(exec)) {
         logInfo(s"Blacklisting executor $exec due to fetch failure")
@@ -249,21 +242,93 @@ private[scheduler] class BlacklistTracker (
         // node, and potentially put the entire node into a blacklist as well.
         val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(node, HashSet[String]())
         blacklistedExecsOnNode += exec
-        // If the node is already in the blacklist, we avoid adding it again with a later expiry
-        // time.
-        if (blacklistedExecsOnNode.size >= MAX_FAILED_EXEC_PER_NODE &&
-            !nodeIdToBlacklistExpiryTime.contains(node)) {
-          logInfo(s"Blacklisting node $node because it has ${blacklistedExecsOnNode.size} " +
-            s"executors blacklisted: ${blacklistedExecsOnNode}")
-          nodeIdToBlacklistExpiryTime.put(node, expiryTimeForNewBlacklists)
-          listenerBus.post(SparkListenerNodeBlacklisted(now, node, blacklistedExecsOnNode.size))
-          _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet)
-          killExecutorsOnBlacklistedNode(node)
+        if (blacklistedExecsOnNode.size >= MAX_FAILED_EXEC_PER_NODE) {
+          val blacklistSucceeded = addNodeToBlacklist(node,
+            ExecutorFailures(Set(blacklistedExecsOnNode.toList: _*)), now)
+          if (blacklistSucceeded) {
+            killExecutorsOnBlacklistedNode(node)
+          }
         }
       }
     }
   }
 
+  /**
+   * Adds a node to the blacklist, with a timeout that depends on the reason. If the node is
+   * already blacklisted, it is not added again with a later expiry time.
+   * @param node Node to be blacklisted
+   * @param reason Reason for blacklisting the node
+   * @param time Start time (defaults to now) from which the blacklist expiry time is computed
+   * @return true if the node was added to the blacklist, false otherwise
+   */
+  def addNodeToBlacklist(node: String, reason: NodeBlacklistReason,
+                         time: Long = clock.getTimeMillis()): Boolean = {
+    // If the node is already in the blacklist, we avoid adding it again with a later expiry time.
+    if (!isNodeBlacklisted(node)) {
+      val blacklistExpiryTimeOpt = reason match {
+        case NodeDecommissioning if DECOMMISSIONING_BLACKLISTING_ENABLED =>
+          val expiryTime = time + BLACKLIST_DECOMMISSIONING_TIMEOUT_MILLIS
+          logInfo(s"Blacklisting node $node with timeout $expiryTime ms because ${reason.message}")
+          Some(expiryTime)
+
+        case ExecutorFailures(blacklistedExecutors) if TASK_BLACKLISTING_ENABLED =>
+          val expiryTime = time + BLACKLIST_TIMEOUT_MILLIS
+          logInfo(s"Blacklisting node $node with timeout $expiryTime ms because it " +
+            s"has ${blacklistedExecutors.size} executors blacklisted: ${blacklistedExecutors}")
+          Some(expiryTime)
+
+        case FetchFailure(host) if BLACKLIST_FETCH_FAILURE_ENABLED =>
+          val expiryTime = time + BLACKLIST_TIMEOUT_MILLIS
+          logInfo(s"Blacklisting node $host due to fetch failure of external shuffle service")
+          Some(expiryTime)
+
+        case _ => None
+      }
+
+      blacklistExpiryTimeOpt.fold(false) { blacklistExpiryTime =>
+        blacklistNodeHelper(node, blacklistExpiryTime)
+        listenerBus.post(SparkListenerNodeBlacklisted(time, node, reason))
+        updateNextExpiryTime()
+        true
+      }
+    }
+    else {
+      false
+    }
+  }
+
+  private def blacklistNodeHelper(node: String, blacklistExpiryTimeout: Long): Unit = {
+    nodeIdToBlacklistExpiryTime.put(node, blacklistExpiryTimeout)
+    _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet)
+  }
+
+  private def unblacklistNodesHelper(nodes: Iterable[String]): Unit = {
+    nodeIdToBlacklistExpiryTime --= nodes
+    _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet)
+  }
+
+  /**
+   * @param nodesToRemove List of nodes to unblacklist, with their reason for unblacklisting
+   *                      and an optional time to be passed to the Spark listener bus as the
+   *                      time of unblacklisting.
+   */
+  def removeNodesFromBlacklist(nodesToRemove: Iterable[(String, NodeUnblacklistReason,
+    Option[Long])]): Unit = {
+    if (nodesToRemove.nonEmpty) {
+      val blacklistNodesToRemove = nodesToRemove.filter{ case (node, reason, _) =>
+        (reason == BlacklistTimedOut ||
+          (reason ==  NodeRunning && DECOMMISSIONING_BLACKLISTING_ENABLED)) &&
+          isNodeBlacklisted(node)
+      }
+      unblacklistNodesHelper(blacklistNodesToRemove.map(_._1))
+      blacklistNodesToRemove.foreach(node => {
+        logInfo(s"Removing node ${node._1} from blacklist because ${node._2.message}")
+        listenerBus.post(SparkListenerNodeUnblacklisted(
+          node._3.getOrElse(clock.getTimeMillis()), node._1, node._2))
+      })
+    }
+  }
+
   def isExecutorBlacklisted(executorId: String): Boolean = {
     executorIdToBlacklistStatus.contains(executorId)
   }
@@ -373,15 +438,39 @@ private[scheduler] class BlacklistTracker (
 private[scheduler] object BlacklistTracker extends Logging {
 
   private val DEFAULT_TIMEOUT = "1h"
+  private val DEFAULT_DECOMMISSIONING_TIMEOUT = "1h"
 
   /**
-   * Returns true if the blacklist is enabled, based on checking the configuration in the following
-   * order:
+   * Returns true if task execution blacklisting, fetch failure blacklisting,
+   * or decommissioning blacklisting is enabled.
+   */
+  def isBlacklistEnabled(conf: SparkConf): Boolean = {
+    isFetchFailureBlacklistingEnabled(conf) || isDecommissioningBlacklistingEnabled(conf) ||
+      isTaskExecutionBlacklistingEnabled(conf)
+  }
+
+  /**
+   * Returns true if the fetch failure blacklisting is enabled
+   */
+  def isFetchFailureBlacklistingEnabled(conf: SparkConf): Boolean = {
+    conf.get(config.BLACKLIST_FETCH_FAILURE_ENABLED)
+  }
+
+  /**
+   * Returns true if the decommission blacklisting is enabled
+   */
+  def isDecommissioningBlacklistingEnabled(conf: SparkConf): Boolean = {
+    conf.get(config.BLACKLIST_DECOMMISSIONING_ENABLED)
+  }
+
+  /**
+   * Returns true if the task execution blacklist is enabled, based on checking the configuration
+   * in the following order:
    * 1. Is it specifically enabled or disabled?
    * 2. Is it enabled via the legacy timeout conf?
    * 3. Default is off
    */
-  def isBlacklistEnabled(conf: SparkConf): Boolean = {
+  def isTaskExecutionBlacklistingEnabled(conf: SparkConf): Boolean = {
     conf.get(config.BLACKLIST_ENABLED) match {
       case Some(enabled) =>
         enabled
@@ -409,6 +498,11 @@ private[scheduler] object BlacklistTracker extends Logging {
     }
   }
 
+  def getBlacklistDecommissioningTimeout(conf: SparkConf): Long = {
+    conf.get(config.BLACKLIST_DECOMMISSIONING_TIMEOUT_CONF)
+      .getOrElse(Utils.timeStringAsMs(DEFAULT_DECOMMISSIONING_TIMEOUT))
+  }
+
   /**
    * Verify that blacklist configurations are consistent; if not, throw an exception.  Should only
    * be called if blacklisting is enabled.
@@ -449,6 +543,12 @@ private[scheduler] object BlacklistTracker extends Logging {
       }
     }
 
+    val blacklistDecommissioningTimeout = getBlacklistDecommissioningTimeout(conf)
+    if (blacklistDecommissioningTimeout <= 0) {
+      mustBePos(config.BLACKLIST_DECOMMISSIONING_TIMEOUT_CONF.key,
+        blacklistDecommissioningTimeout.toString)
+    }
+
     val maxTaskFailures = conf.get(config.MAX_TASK_FAILURES)
     val maxNodeAttempts = conf.get(config.MAX_TASK_ATTEMPTS_PER_NODE)
 
@@ -458,7 +558,9 @@ private[scheduler] object BlacklistTracker extends Logging {
         s"( = ${maxTaskFailures} ).  Though blacklisting is enabled, with this configuration, " +
         s"Spark will not be robust to one bad node.  Decrease " +
         s"${config.MAX_TASK_ATTEMPTS_PER_NODE.key}, increase ${config.MAX_TASK_FAILURES.key}, " +
-        s"or disable blacklisting with ${config.BLACKLIST_ENABLED.key}")
+        s"or disable blacklisting with ${config.BLACKLIST_ENABLED.key}, " +
+        s"${config.BLACKLIST_DECOMMISSIONING_ENABLED.key} " +
+        s"and ${config.BLACKLIST_FETCH_FAILURE_ENABLED.key}")
     }
   }
 }
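
The heart of the change in addNodeToBlacklist is that each blacklist reason carries its own enable flag and timeout, and a disabled reason simply yields no expiry time. Below is a stripped-down, self-contained sketch of that decision; the names are illustrative and not the PR's API:

    object BlacklistDecisionSketch {
      sealed trait Reason
      case object Decommissioning extends Reason
      final case class ExecutorFailures(execs: Set[String]) extends Reason
      final case class FetchFailure(host: String) extends Reason

      /** Expiry timestamp if blacklisting for this reason is enabled, otherwise None. */
      def expiryFor(
          reason: Reason,
          now: Long,
          decommissioningTimeoutMs: Long,
          failureTimeoutMs: Long,
          decommissioningEnabled: Boolean,
          taskBlacklistingEnabled: Boolean,
          fetchFailureEnabled: Boolean): Option[Long] = reason match {
        case Decommissioning if decommissioningEnabled => Some(now + decommissioningTimeoutMs)
        case ExecutorFailures(_) if taskBlacklistingEnabled => Some(now + failureTimeoutMs)
        case FetchFailure(_) if fetchFailureEnabled => Some(now + failureTimeoutMs)
        case _ => None
      }
    }
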
diff --git a/core/src/main/scala/org/apache/spark/scheduler/NodeBlacklistReason.scala b/core/src/main/scala/org/apache/spark/scheduler/NodeBlacklistReason.scala
new file mode 100644
index 0000000000000..53fba245e0ee5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/NodeBlacklistReason.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * Represents an explanation for a Node being blacklisted for task scheduling
+ */
+@DeveloperApi
+private[spark] sealed trait NodeBlacklistReason extends Serializable {
+  def message: String
+}
+
+@DeveloperApi
+private[spark] case class ExecutorFailures(blacklistedExecutors: Set[String])
+  extends NodeBlacklistReason {
+  override def message: String = "Maximum number of executor failures allowed on Node exceeded."
+}
+
+@DeveloperApi
+private[spark] case object NodeDecommissioning extends NodeBlacklistReason {
+  override def message: String = "Node is being decommissioned by Cluster Manager."
+}
+
+@DeveloperApi
+private[spark] case class FetchFailure(host: String) extends NodeBlacklistReason {
+  override def message: String = s"Fetch failure for host $host"
+}
+
diff --git a/core/src/main/scala/org/apache/spark/scheduler/NodeUnblacklistReason.scala b/core/src/main/scala/org/apache/spark/scheduler/NodeUnblacklistReason.scala
new file mode 100644
index 0000000000000..b388ff997fa32
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/NodeUnblacklistReason.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * Represents an explanation for a Node being unblacklisted for task scheduling.
+ */
+@DeveloperApi
+private[spark] sealed trait NodeUnblacklistReason extends Serializable {
+  def message: String
+}
+
+@DeveloperApi
+private[spark] object BlacklistTimedOut extends NodeUnblacklistReason {
+  override def message: String = "Blacklist timeout has been reached."
+}
+
+@DeveloperApi
+private[spark] object NodeRunning extends NodeUnblacklistReason {
+  override def message: String = "Node is active and back to Running state."
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 59f89a82a1da8..5fe9cd76de38c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -125,11 +125,11 @@ case class SparkListenerExecutorUnblacklisted(time: Long, executorId: String)
 case class SparkListenerNodeBlacklisted(
     time: Long,
     hostId: String,
-    executorFailures: Int)
+    reason: NodeBlacklistReason)
   extends SparkListenerEvent
 
 @DeveloperApi
-case class SparkListenerNodeUnblacklisted(time: Long, hostId: String)
+case class SparkListenerNodeUnblacklisted(time: Long, hostId: String, reason: NodeUnblacklistReason)
   extends SparkListenerEvent
 
 @DeveloperApi
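
With this signature change, the SparkListener.onNodeBlacklisted and onNodeUnblacklisted callbacks receive a structured reason rather than a bare executor-failure count. A hedged sketch of a listener that surfaces it is below; note that the reason traits are private[spark] in this diff, so code like this would have to live inside the Spark source tree (or the traits would need to be widened) to compile:

    import org.apache.spark.scheduler.{SparkListener, SparkListenerNodeBlacklisted}
    import org.apache.spark.scheduler.SparkListenerNodeUnblacklisted

    // Illustrative only: logs why a node entered or left the blacklist.
    class NodeBlacklistLogger extends SparkListener {
      override def onNodeBlacklisted(event: SparkListenerNodeBlacklisted): Unit =
        println(s"Node ${event.hostId} blacklisted at ${event.time}: ${event.reason.message}")

      override def onNodeUnblacklisted(event: SparkListenerNodeUnblacklisted): Unit =
        println(s"Node ${event.hostId} unblacklisted at ${event.time}: ${event.reason.message}")
    }

    // Registered like any other listener, e.g. sc.addSparkListener(new NodeBlacklistLogger)
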
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 737b383631148..8c23f7cd7769e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -656,6 +656,14 @@ private[spark] class TaskSchedulerImpl(
     blacklistTrackerOpt.map(_.nodeBlacklist()).getOrElse(scala.collection.immutable.Set())
   }
 
+  def blacklistExecutorsOnHost(host: String, reason: NodeBlacklistReason): Unit = synchronized {
+    blacklistTrackerOpt.foreach(_.addNodeToBlacklist(host, reason))
+  }
+
+  def unblacklistExecutorsOnHost(host: String, reason: NodeUnblacklistReason): Unit = synchronized {
+    blacklistTrackerOpt.foreach(_.removeNodesFromBlacklist(List((host, reason, None))))
+  }
+
   // By default, rack is unknown
   def getRackForHost(value: String): Option[String] = None
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index c2f817858473c..0e75032b3de5d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -96,8 +96,10 @@ private[spark] class TaskSetManager(
   private var calculatedTasks = 0
 
   private[scheduler] val taskSetBlacklistHelperOpt: Option[TaskSetBlacklist] = {
-    blacklistTracker.map { _ =>
-      new TaskSetBlacklist(conf, stageId, clock)
+    if (blacklistTracker.isDefined && BlacklistTracker.isTaskExecutionBlacklistingEnabled(conf)) {
+      Some(new TaskSetBlacklist(conf, stageId, clock))
+    } else {
+      None
     }
   }
 
@@ -519,7 +521,7 @@ private[spark] class TaskSetManager(
   private def maybeFinishTaskSet() {
     if (isZombie && runningTasks == 0) {
       sched.taskSetFinished(this)
-      if (tasksSuccessful == numTasks) {
+      if (taskSetBlacklistHelperOpt.isDefined && tasksSuccessful == numTasks) {
         blacklistTracker.foreach(_.updateBlacklistForSuccessfulTaskSet(
           taskSet.stageId,
           taskSet.stageAttemptId,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index 5d65731dfc30e..79fabd380288f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler.cluster
 
 import java.nio.ByteBuffer
 
+import org.apache.spark.HostState.HostState
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.rpc.RpcEndpointRef
 import org.apache.spark.scheduler.ExecutorLossReason
@@ -110,6 +111,9 @@ private[spark] object CoarseGrainedClusterMessages {
       nodeBlacklist: Set[String])
     extends CoarseGrainedClusterMessage
 
+  case class HostStatusUpdate(host: String, hostState: HostState)
+    extends CoarseGrainedClusterMessage
+
   // Check if an executor was force-killed but for a reason unrelated to the running tasks.
   // This could be the case if the executor is preempted, for instance.
   case class GetExecutorLossReason(executorId: String) extends CoarseGrainedClusterMessage
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index a0ef209779309..d8857a19f9c8f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -26,7 +26,7 @@ import scala.concurrent.Future
 
 import org.apache.hadoop.security.UserGroupInformation
 
-import org.apache.spark.{ExecutorAllocationClient, SparkEnv, SparkException, TaskState}
+import org.apache.spark.{ExecutorAllocationClient, HostState, SparkEnv, SparkException, TaskState}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.deploy.security.HadoopDelegationTokenManager
 import org.apache.spark.internal.Logging
@@ -482,6 +482,22 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
     }(ThreadUtils.sameThread)
   }
 
+  private[scheduler] def handleUpdatedHostState(host: String,
+                                                hostState: HostState.HostState): Unit = {
+    hostState match {
+      case HostState.Decommissioning =>
+        scheduler.blacklistExecutorsOnHost(host, NodeDecommissioning)
+
+      case HostState.Running =>
+        scheduler.unblacklistExecutorsOnHost(host, NodeRunning)
+
+      case HostState.Decommissioned | HostState.Lost =>
+        // TODO: Take action when a node is Decommissioned or Lost
+
+      case _ =>
+    }
+  }
+
   def sufficientResourcesRegistered(): Boolean = true
 
   override def isReady(): Boolean = {
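
handleUpdatedHostState is the single place where a cluster-manager node-state change is turned into a scheduler action, which keeps the policy small and easy to test. The same dispatch, reduced to a pure self-contained function, is sketched below; the action names are illustrative and not the PR's API:

    object HostStateDispatchSketch {
      sealed trait Action
      case object BlacklistHost extends Action
      case object UnblacklistHost extends Action

      // Mirrors handleUpdatedHostState above: only DECOMMISSIONING and RUNNING
      // trigger an action today; DECOMMISSIONED/LOST are left as a TODO in the
      // PR and every other state is ignored.
      def actionFor(yarnState: String): Option[Action] = yarnState.toUpperCase match {
        case "DECOMMISSIONING" => Some(BlacklistHost)
        case "RUNNING" => Some(UnblacklistHost)
        case _ => None
      }
    }
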
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index 8406826a228db..9925ad3c18abc 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -100,6 +100,10 @@ private[spark] object JsonProtocol {
         executorMetricsUpdateToJson(metricsUpdate)
       case blockUpdated: SparkListenerBlockUpdated =>
         throw new MatchError(blockUpdated)  // TODO(ekl) implement this
+      case nodeBlacklisted: SparkListenerNodeBlacklisted =>
+        nodeBlacklistedToJson(nodeBlacklisted)
+      case nodeUnblacklisted: SparkListenerNodeUnblacklisted =>
+        nodeUnblacklistedToJson(nodeUnblacklisted)
       case _ => parse(mapper.writeValueAsString(event))
     }
   }
@@ -246,6 +250,20 @@ private[spark] object JsonProtocol {
     })
   }
 
+  def nodeBlacklistedToJson(nodeBlacklisted: SparkListenerNodeBlacklisted): JValue = {
+    ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.nodeBlacklisted) ~
+    ("hostId" -> nodeBlacklisted.hostId) ~
+    ("time" -> nodeBlacklisted.time) ~
+    ("blacklistReason" -> nodeBlacklistReasonToJson(nodeBlacklisted.reason))
+  }
+
+  def nodeUnblacklistedToJson(nodeUnblacklisted: SparkListenerNodeUnblacklisted): JValue = {
+    ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.nodeUnblacklisted) ~
+    ("hostId" -> nodeUnblacklisted.hostId) ~
+    ("time" -> nodeUnblacklisted.time) ~
+    ("unblacklistReason" -> nodeUnblacklistReasonToJson(nodeUnblacklisted.reason))
+  }
+
   /** ------------------------------------------------------------------- *
    * JSON serialization methods for classes SparkListenerEvents depend on |
    * -------------------------------------------------------------------- */
@@ -407,6 +425,24 @@ private[spark] object JsonProtocol {
     ("Reason" -> reason) ~ json
   }
 
+  def nodeBlacklistReasonToJson(nodeBlacklistReason: NodeBlacklistReason): JValue = {
+    val reason = Utils.getFormattedClassName(nodeBlacklistReason)
+    val json: JObject = nodeBlacklistReason match {
+      case ExecutorFailures(blacklistedExecutors) =>
+        ("blacklistedExecutors" -> blacklistedExecutors)
+      case NodeDecommissioning =>
+        Utils.emptyJson
+      case FetchFailure(host) =>
+        ("host" -> host)
+    }
+    ("reason" -> reason) ~ json
+  }
+
+  def nodeUnblacklistReasonToJson(nodeUnblacklistReason: NodeUnblacklistReason): JValue = {
+    val reason = Utils.getFormattedClassName(nodeUnblacklistReason)
+    "reason" -> reason
+  }
+
   def blockManagerIdToJson(blockManagerId: BlockManagerId): JValue = {
     ("Executor ID" -> blockManagerId.executorId) ~
     ("Host" -> blockManagerId.host) ~
@@ -515,6 +551,8 @@ private[spark] object JsonProtocol {
     val executorRemoved = Utils.getFormattedClassName(SparkListenerExecutorRemoved)
     val logStart = Utils.getFormattedClassName(SparkListenerLogStart)
     val metricsUpdate = Utils.getFormattedClassName(SparkListenerExecutorMetricsUpdate)
+    val nodeBlacklisted = Utils.getFormattedClassName(SparkListenerNodeBlacklisted)
+    val nodeUnblacklisted = Utils.getFormattedClassName(SparkListenerNodeUnblacklisted)
   }
 
   def sparkEventFromJson(json: JValue): SparkListenerEvent = {
@@ -538,6 +576,8 @@ private[spark] object JsonProtocol {
       case `executorRemoved` => executorRemovedFromJson(json)
       case `logStart` => logStartFromJson(json)
       case `metricsUpdate` => executorMetricsUpdateFromJson(json)
+      case `nodeBlacklisted` => nodeBlacklistedFromJson(json)
+      case `nodeUnblacklisted` => nodeUnBlacklistedFromJson(json)
       case other => mapper.readValue(compact(render(json)), Utils.classForName(other))
         .asInstanceOf[SparkListenerEvent]
     }
@@ -676,6 +716,20 @@ private[spark] object JsonProtocol {
     SparkListenerExecutorMetricsUpdate(execInfo, accumUpdates)
   }
 
+  def nodeBlacklistedFromJson(json: JValue): SparkListenerNodeBlacklisted = {
+    val host = (json \ "hostId").extract[String]
+    val time = (json \ "time").extract[Long]
+    val reason = nodeBlacklistReasonFromJson(json \ "blacklistReason")
+    SparkListenerNodeBlacklisted(time, host, reason)
+  }
+
+  def nodeUnBlacklistedFromJson(json: JValue): SparkListenerNodeUnblacklisted = {
+    val host = (json \ "hostId").extract[String]
+    val time = (json \ "time").extract[Long]
+    val reason = nodeUnblacklistReasonFromJson(json \ "unblacklistReason")
+    SparkListenerNodeUnblacklisted(time, host, reason)
+  }
+
   /** --------------------------------------------------------------------- *
    * JSON deserialization methods for classes SparkListenerEvents depend on |
    * ---------------------------------------------------------------------- */
@@ -917,6 +971,42 @@ private[spark] object JsonProtocol {
     }
   }
 
+  private object NODE_BLACKLIST_REASON_FORMATTED_CLASS_NAMES {
+    val executorFailures = Utils.getFormattedClassName(ExecutorFailures)
+    val nodeDecommissioning = Utils.getFormattedClassName(NodeDecommissioning)
+    val fetchFailure = Utils.getFormattedClassName(FetchFailure)
+  }
+
+  def nodeBlacklistReasonFromJson(json: JValue): NodeBlacklistReason = {
+    import NODE_BLACKLIST_REASON_FORMATTED_CLASS_NAMES._
+
+    (json \ "reason").extract[String] match {
+      case `executorFailures` =>
+        val blacklistedExecutors = (json \ "blacklistedExecutors").extract[List[String]]
+        new ExecutorFailures(Set(blacklistedExecutors: _*))
+
+      case `nodeDecommissioning` => NodeDecommissioning
+
+      case `fetchFailure` =>
+        val host = (json \ "host").extract[String]
+        new FetchFailure(host)
+    }
+  }
+
+  private object NODE_UNBLACKLIST_REASON_FORMATTED_CLASS_NAMES {
+    val blacklistTimedOut = Utils.getFormattedClassName(BlacklistTimedOut)
+    val nodeRunning = Utils.getFormattedClassName(NodeRunning)
+  }
+
+  def nodeUnblacklistReasonFromJson(json: JValue): NodeUnblacklistReason = {
+    import NODE_UNBLACKLIST_REASON_FORMATTED_CLASS_NAMES._
+
+    (json \ "reason").extract[String] match {
+      case `blacklistTimedOut` => BlacklistTimedOut
+      case `nodeRunning` => NodeRunning
+    }
+  }
+
   def blockManagerIdFromJson(json: JValue): BlockManagerId = {
     // On metadata fetch fail, block manager ID can be null (SPARK-4471)
     if (json == JNothing) {
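
The JsonProtocol additions follow the file's existing pattern: write a "reason" discriminator plus any reason-specific fields, then dispatch on that discriminator when reading the event log back. Below is a self-contained json4s sketch of the same round trip, with simplified reason types that are not the PR's:

    import org.json4s._
    import org.json4s.JsonDSL._
    import org.json4s.jackson.JsonMethods.{compact, render}

    object ReasonJsonSketch {
      implicit val formats: Formats = DefaultFormats

      sealed trait Reason
      case object Decommissioning extends Reason
      final case class ExecutorFailures(execs: Set[String]) extends Reason

      def toJson(reason: Reason): JValue = reason match {
        case Decommissioning =>
          // A reason with no extra fields serializes to just the discriminator.
          ("reason" -> "NodeDecommissioning"): JValue
        case ExecutorFailures(execs) =>
          ("reason" -> "ExecutorFailures") ~ ("blacklistedExecutors" -> execs.toList)
      }

      def fromJson(json: JValue): Reason = (json \ "reason").extract[String] match {
        case "NodeDecommissioning" => Decommissioning
        case "ExecutorFailures" =>
          ExecutorFailures((json \ "blacklistedExecutors").extract[List[String]].toSet)
      }

      // compact(render(toJson(ExecutorFailures(Set("1", "2")))))
      // => {"reason":"ExecutorFailures","blacklistedExecutors":["1","2"]}
    }
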
diff --git a/core/src/test/resources/spark-events/app-20161115172038-0000 b/core/src/test/resources/spark-events/app-20161115172038-0000
index 3af0451d0c392..f0e919b364b67 100755
--- a/core/src/test/resources/spark-events/app-20161115172038-0000
+++ b/core/src/test/resources/spark-events/app-20161115172038-0000
@@ -68,8 +68,8 @@
 {"Event":"SparkListenerJobEnd","Job ID":0,"Completion Time":1479252044931,"Job Result":{"Result":"JobSucceeded"}}
 {"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklisted","time":1479252044930,"executorId":"2","taskFailures":4}
 {"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklisted","time":1479252044930,"executorId":"0","taskFailures":4}
-{"Event":"org.apache.spark.scheduler.SparkListenerNodeBlacklisted","time":1479252044930,"hostId":"172.22.0.111","executorFailures":2}
+{"Event":"SparkListenerNodeBlacklisted","time":1479252044930,"hostId":"172.22.0.111","blacklistReason":{"reason":"ExecutorFailures","blacklistedExecutors":["exec1","exec2","exec3"]}}
 {"Event":"org.apache.spark.scheduler.SparkListenerExecutorUnblacklisted","time":1479252055635,"executorId":"2"}
 {"Event":"org.apache.spark.scheduler.SparkListenerExecutorUnblacklisted","time":1479252055635,"executorId":"0"}
-{"Event":"org.apache.spark.scheduler.SparkListenerNodeUnblacklisted","time":1479252055635,"hostId":"172.22.0.111"}
+{"Event":"SparkListenerNodeUnblacklisted","time":1479252055635,"hostId":"172.22.0.111","unblacklistReason":{"reason":"BlacklistTimedOut"}}
 {"Event":"SparkListenerApplicationEnd","Timestamp":1479252138874}
diff --git a/core/src/test/resources/spark-events/app-20161116163331-0000 b/core/src/test/resources/spark-events/app-20161116163331-0000
index 57cfc5b973129..44b6554c6ce63 100755
--- a/core/src/test/resources/spark-events/app-20161116163331-0000
+++ b/core/src/test/resources/spark-events/app-20161116163331-0000
@@ -64,5 +64,5 @@
 {"Event":"SparkListenerJobEnd","Job ID":0,"Completion Time":1479335617480,"Job Result":{"Result":"JobSucceeded"}}
 {"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklisted","time":1479335617478,"executorId":"2","taskFailures":4}
 {"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklisted","time":1479335617478,"executorId":"0","taskFailures":4}
-{"Event":"org.apache.spark.scheduler.SparkListenerNodeBlacklisted","time":1479335617478,"hostId":"172.22.0.167","executorFailures":2}
+{"Event":"SparkListenerNodeBlacklisted","time":1479335617478,"hostId":"172.22.0.167","blacklistReason":{"reason":"ExecutorFailures","blacklistedExecutors":["exec1","exec2","exec3"]}}
 {"Event":"SparkListenerApplicationEnd","Timestamp":1479335620587}
diff --git a/core/src/test/scala/org/apache/spark/HostStateSuite.scala b/core/src/test/scala/org/apache/spark/HostStateSuite.scala
new file mode 100644
index 0000000000000..95862ce127e72
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/HostStateSuite.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.Matchers
+import org.scalatest.prop.TableDrivenPropertyChecks.{forAll, Table}
+
+import org.apache.spark.HostState.HostState
+
+class HostStateSuite extends SparkFunSuite with Matchers {
+
+  test("Contract for the conversion between YARN NodeState and HostState") {
+    val mappings =
+      Table(
+        ("yarnNodeState", "hostState"),
+        (HostState.toYarnState(HostState.New), HostState.New),
+        (HostState.toYarnState(HostState.Running), HostState.Running),
+        (HostState.toYarnState(HostState.Decommissioned), HostState.Decommissioned),
+        (HostState.toYarnState(HostState.Decommissioning), HostState.Decommissioning),
+        (HostState.toYarnState(HostState.Unhealthy), HostState.Unhealthy),
+        (HostState.toYarnState(HostState.Rebooted), HostState.Rebooted))
+
+    forAll (mappings) { (yarnNodeState: Option[String], hostState: HostState) =>
+      assert(yarnNodeState.isDefined)
+      val hostStateOpt = HostState.fromYarnState(yarnNodeState.get)
+      assert(hostStateOpt.isDefined)
+      hostStateOpt.get should be (hostState)
+    }
+  }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala
index f6015cd51c2bd..07136606f545e 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala
@@ -44,7 +44,8 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM
   // according to locality preferences, and so the job fails
   testScheduler("If preferred node is bad, without blacklist job will fail",
     extraConfs = Seq(
-      config.BLACKLIST_ENABLED.key -> "false"
+      config.BLACKLIST_ENABLED.key -> "false",
+      config.BLACKLIST_DECOMMISSIONING_ENABLED.key -> "true"
   )) {
     val rdd = new MockRDDWithLocalityPrefs(sc, 10, Nil, badHost)
     withBackend(badHostBackend _) {
@@ -58,6 +59,7 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM
     "With default settings, job can succeed despite multiple bad executors on node",
     extraConfs = Seq(
       config.BLACKLIST_ENABLED.key -> "true",
+      config.BLACKLIST_DECOMMISSIONING_ENABLED.key -> "false",
       config.MAX_TASK_FAILURES.key -> "4",
       "spark.testing.nHosts" -> "2",
       "spark.testing.nExecutorsPerHost" -> "5",
@@ -84,6 +86,7 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM
     "Bad node with multiple executors, job will still succeed with the right confs",
     extraConfs = Seq(
        config.BLACKLIST_ENABLED.key -> "true",
+       config.BLACKLIST_DECOMMISSIONING_ENABLED.key -> "false",
       // just to avoid this test taking too long
       "spark.locality.wait" -> "10ms"
     )
@@ -103,6 +106,7 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM
     "SPARK-15865 Progress with fewer executors than maxTaskFailures",
     extraConfs = Seq(
       config.BLACKLIST_ENABLED.key -> "true",
+      config.BLACKLIST_DECOMMISSIONING_ENABLED.key -> "false",
       "spark.testing.nHosts" -> "2",
       "spark.testing.nExecutorsPerHost" -> "1",
       "spark.testing.nCoresPerExecutor" -> "1"
diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
index 520d85a298922..00139f5e1d57d 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
@@ -41,6 +41,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
   override def beforeEach(): Unit = {
     conf = new SparkConf().setAppName("test").setMaster("local")
       .set(config.BLACKLIST_ENABLED.key, "true")
+      .set(config.BLACKLIST_DECOMMISSIONING_ENABLED.key, "false")
     scheduler = mockTaskSchedWithConf(conf)
 
     clock.setTime(0)
@@ -188,7 +189,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
     blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist1.execToFailures)
     assert(blacklist.nodeBlacklist() === Set("hostA"))
     assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set("hostA"))
-    verify(listenerBusMock).post(SparkListenerNodeBlacklisted(0, "hostA", 2))
+    verify(listenerBusMock).post(SparkListenerNodeBlacklisted(0, "hostA",
+      ExecutorFailures(Set("1", "2"))))
     assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1", "2"))
     verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(0, "2", 4))
 
@@ -202,7 +204,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
     assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set())
     verify(listenerBusMock).post(SparkListenerExecutorUnblacklisted(timeout, "2"))
     verify(listenerBusMock).post(SparkListenerExecutorUnblacklisted(timeout, "1"))
-    verify(listenerBusMock).post(SparkListenerNodeUnblacklisted(timeout, "hostA"))
+    verify(listenerBusMock).post(SparkListenerNodeUnblacklisted(timeout, "hostA",
+      BlacklistTimedOut))
 
     // Fail one more task, but executor isn't put back into blacklist since the count of failures
     // on that executor should have been reset to 0.
@@ -248,7 +251,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
     assert(blacklist.isExecutorBlacklisted("2"))
     verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(t1, "2", 4))
     assert(blacklist.isNodeBlacklisted("hostA"))
-    verify(listenerBusMock).post(SparkListenerNodeBlacklisted(t1, "hostA", 2))
+    verify(listenerBusMock).post(SparkListenerNodeBlacklisted(t1, "hostA",
+      ExecutorFailures(Set("1", "2"))))
 
     // Advance the clock so that executor 1 should no longer be explicitly blacklisted, but
     // everything else should still be blacklisted.
@@ -266,7 +270,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
     clock.advance(t1)
     blacklist.applyBlacklistTimeout()
     assert(!blacklist.nodeIdToBlacklistExpiryTime.contains("hostA"))
-    verify(listenerBusMock).post(SparkListenerNodeUnblacklisted(t1 + t2 + t1, "hostA"))
+    verify(listenerBusMock).post(SparkListenerNodeUnblacklisted(t1 + t2 + t1, "hostA",
+      BlacklistTimedOut))
     // Even though unblacklisting a node implicitly unblacklists all of its executors,
     // there will be no SparkListenerExecutorUnblacklisted sent here.
   }
@@ -401,14 +406,15 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
     assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1", "2", "3"))
     verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(0, "3", 2))
     assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set("hostA"))
-    verify(listenerBusMock).post(SparkListenerNodeBlacklisted(0, "hostA", 2))
+    verify(listenerBusMock).post(SparkListenerNodeBlacklisted(0, "hostA",
+      ExecutorFailures(Set("1", "3"))))
   }
 
   test("blacklist still respects legacy configs") {
     val conf = new SparkConf().setMaster("local")
-    assert(!BlacklistTracker.isBlacklistEnabled(conf))
+    assert(!BlacklistTracker.isTaskExecutionBlacklistingEnabled(conf))
     conf.set(config.BLACKLIST_LEGACY_TIMEOUT_CONF, 5000L)
-    assert(BlacklistTracker.isBlacklistEnabled(conf))
+    assert(BlacklistTracker.isTaskExecutionBlacklistingEnabled(conf))
     assert(5000 === BlacklistTracker.getBlacklistTimeout(conf))
     // the new conf takes precedence, though
     conf.set(config.BLACKLIST_TIMEOUT_CONF, 1000L)
@@ -416,10 +422,10 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
 
     // if you explicitly set the legacy conf to 0, that also would disable blacklisting
     conf.set(config.BLACKLIST_LEGACY_TIMEOUT_CONF, 0L)
-    assert(!BlacklistTracker.isBlacklistEnabled(conf))
+    assert(!BlacklistTracker.isTaskExecutionBlacklistingEnabled(conf))
     // but again, the new conf takes precedence
     conf.set(config.BLACKLIST_ENABLED, true)
-    assert(BlacklistTracker.isBlacklistEnabled(conf))
+    assert(BlacklistTracker.isTaskExecutionBlacklistingEnabled(conf))
     assert(1000 === BlacklistTracker.getBlacklistTimeout(conf))
   }
 
@@ -439,7 +445,9 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
         s"( = ${maxTaskFailures} ).  Though blacklisting is enabled, with this configuration, " +
         s"Spark will not be robust to one bad node.  Decrease " +
         s"${config.MAX_TASK_ATTEMPTS_PER_NODE.key}, increase ${config.MAX_TASK_FAILURES.key}, " +
-        s"or disable blacklisting with ${config.BLACKLIST_ENABLED.key}")
+        s"or disable blacklisting with ${config.BLACKLIST_ENABLED.key}, " +
+        s"${config.BLACKLIST_DECOMMISSIONING_ENABLED.key} " +
+        s"and ${config.BLACKLIST_FETCH_FAILURE_ENABLED.key}")
     }
 
     conf.remove(config.MAX_TASK_FAILURES)
@@ -452,7 +460,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
       config.MAX_FAILED_EXEC_PER_NODE_STAGE,
       config.MAX_FAILURES_PER_EXEC,
       config.MAX_FAILED_EXEC_PER_NODE,
-      config.BLACKLIST_TIMEOUT_CONF
+      config.BLACKLIST_TIMEOUT_CONF,
+      config.BLACKLIST_DECOMMISSIONING_TIMEOUT_CONF
     ).foreach { config =>
       conf.set(config.key, "0")
       val excMsg = intercept[IllegalArgumentException] {
@@ -585,4 +594,72 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
       2000 + blacklist.BLACKLIST_TIMEOUT_MILLIS)
     assert(blacklist.nextExpiryTime === 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS)
   }
+
+  test("node is blacklisted with NodeDecommissioning reason and gets recovered with time") {
+    conf.set(config.BLACKLIST_DECOMMISSIONING_ENABLED.key, "true")
+    blacklist = new BlacklistTracker(listenerBusMock, conf, None, clock)
+    blacklist.addNodeToBlacklist("hostA", NodeDecommissioning)
+    assert(blacklist.nodeBlacklist() === Set("hostA"))
+    assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set("hostA"))
+    verify(listenerBusMock).post(SparkListenerNodeBlacklisted(0, "hostA", NodeDecommissioning))
+
+    val timeout = blacklist.BLACKLIST_DECOMMISSIONING_TIMEOUT_MILLIS + 1
+    clock.advance(timeout)
+    blacklist.applyBlacklistTimeout()
+    assert(blacklist.nodeBlacklist() === Set())
+    assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set())
+    verify(listenerBusMock).post(SparkListenerNodeUnblacklisted(timeout, "hostA",
+      BlacklistTimedOut))
+  }
+
+  test("node is unblacklisted with NodeRunning reason") {
+    conf.set(config.BLACKLIST_DECOMMISSIONING_ENABLED.key, "true")
+    blacklist = new BlacklistTracker(listenerBusMock, conf, None, clock)
+    val now = clock.getTimeMillis()
+    blacklist.addNodeToBlacklist("hostA", NodeDecommissioning)
+    blacklist.addNodeToBlacklist("hostB", ExecutorFailures(Set()))
+    assert(blacklist.nodeBlacklist() === Set("hostA", "hostB"))
+    assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set("hostA", "hostB"))
+
+    blacklist.removeNodesFromBlacklist(List(("hostA", NodeRunning, Some(now))))
+    assert(blacklist.nodeBlacklist() === Set("hostB"))
+    assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set("hostB"))
+    verify(listenerBusMock).post(SparkListenerNodeUnblacklisted(now, "hostA", NodeRunning))
+  }
+
+  (for {
+    taskExecutionBlacklistingEnabled <- Seq(true, false)
+    decommissioningBlacklistingEnabled <- Seq(true, false)
+  } yield (taskExecutionBlacklistingEnabled, decommissioningBlacklistingEnabled)).foreach {
+    case (taskExecutionBlacklistingEnabled, decommissioningBlacklistingEnabled) =>
+      val blacklistStatusMsgDict = Map(true -> "enforced", false -> "ignored")
+
+      test(s"task execution blacklisting is " +
+          s"${blacklistStatusMsgDict(taskExecutionBlacklistingEnabled)} due to " +
+          s"${config.BLACKLIST_ENABLED.key}=$taskExecutionBlacklistingEnabled, while " +
+          s"decommissioning blacklisting is " +
+          s"${blacklistStatusMsgDict(decommissioningBlacklistingEnabled)} " +
+          s"due to " +
+          s"${config.BLACKLIST_DECOMMISSIONING_ENABLED.key}=$decommissioningBlacklistingEnabled") {
+        conf = conf.set(config.BLACKLIST_ENABLED.key, taskExecutionBlacklistingEnabled.toString).
+          set(config.BLACKLIST_DECOMMISSIONING_ENABLED.key,
+              decommissioningBlacklistingEnabled.toString)
+        blacklist = new BlacklistTracker(listenerBusMock, conf, None, clock)
+
+        val (failingHost, decommissioningHost) = ("hostFailing", "hostDecommissioning")
+        blacklist.addNodeToBlacklist(failingHost, ExecutorFailures(Set()))
+        blacklist.addNodeToBlacklist(decommissioningHost, NodeDecommissioning)
+        val blacklistedHosts = (if (taskExecutionBlacklistingEnabled) {
+          Set(failingHost)
+        } else {
+          Set[String]()
+        }) ++ (if (decommissioningBlacklistingEnabled) {
+          Set(decommissioningHost)
+        } else {
+          Set[String]()
+        })
+        assert(blacklist.nodeBlacklist() === blacklistedHosts)
+        assertEquivalentToSet(blacklist.isNodeBlacklisted(_), blacklistedHosts)
+      }
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala
index 04cccc67e328e..effc788d454b7 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala
@@ -17,13 +17,46 @@
 
 package org.apache.spark.scheduler
 
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite}
+import org.mockito.Mockito.{verify, when}
+import org.scalatest.mock.MockitoSugar
+
+import org.apache.spark.{HostState, LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite}
+import org.apache.spark.rpc.RpcEnv
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
 import org.apache.spark.util.{RpcUtils, SerializableBuffer}
 
-class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext {
+class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with MockitoSugar
+  with LocalSparkContext {
+
+  private var conf: SparkConf = _
+  private var scheduler: TaskSchedulerImpl = _
+  private var schedulerBackend: CoarseGrainedSchedulerBackend = _
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    conf = new SparkConf
+  }
+
+  override def afterEach(): Unit = {
+    super.afterEach()
+    if (scheduler != null) {
+      scheduler.stop()
+      scheduler = null
+    }
+    if (schedulerBackend != null) {
+      schedulerBackend.stop()
+      schedulerBackend = null
+    }
+  }
+
+  private def setupSchedulerBackend(): Unit = {
+    sc = new SparkContext("local", "test", conf)
+    scheduler = mock[TaskSchedulerImpl]
+    when(scheduler.sc).thenReturn(sc)
+    schedulerBackend = new CoarseGrainedSchedulerBackend(scheduler, mock[RpcEnv])
+  }
 
   test("serialized task larger than max RPC message size") {
-    val conf = new SparkConf
     conf.set("spark.rpc.message.maxSize", "1")
     conf.set("spark.default.parallelism", "1")
     sc = new SparkContext("local-cluster[2, 1, 1024]", "test", conf)
@@ -38,4 +71,13 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo
     assert(smaller.size === 4)
   }
 
+  test("handle updated node status received") {
+    setupSchedulerBackend()
+    schedulerBackend.handleUpdatedHostState("host1", HostState.Decommissioning)
+    verify(scheduler).blacklistExecutorsOnHost("host1", NodeDecommissioning)
+
+    schedulerBackend.handleUpdatedHostState("host1", HostState.Running)
+    verify(scheduler).unblacklistExecutorsOnHost("host1", NodeRunning)
+  }
+
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index ab67a393e2ac5..8c631d1d466b8 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -85,6 +85,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
     blacklist = mock[BlacklistTracker]
     val conf = new SparkConf().setMaster("local").setAppName("TaskSchedulerImplSuite")
     conf.set(config.BLACKLIST_ENABLED, true)
+        .set(config.BLACKLIST_DECOMMISSIONING_ENABLED, false)
     sc = new SparkContext(conf)
     taskScheduler =
       new TaskSchedulerImpl(sc, sc.conf.getInt("spark.task.maxFailures", 4)) {
@@ -621,7 +622,8 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
     // schedulable on another executor.  However, that executor may fail later on, leaving the
     // first task with no place to run.
     val taskScheduler = setupScheduler(
-      config.BLACKLIST_ENABLED.key -> "true"
+      config.BLACKLIST_ENABLED.key -> "true",
+      config.BLACKLIST_DECOMMISSIONING_ENABLED.key -> "false"
     )
 
     val taskSet = FakeTask.createTaskSet(2)
@@ -672,7 +674,8 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
     // available and not bail on the job
 
     val taskScheduler = setupScheduler(
-      config.BLACKLIST_ENABLED.key -> "true"
+      config.BLACKLIST_ENABLED.key -> "true",
+      config.BLACKLIST_DECOMMISSIONING_ENABLED.key -> "false"
     )
 
     val taskSet = FakeTask.createTaskSet(2, (0 until 2).map { _ => Seq(TaskLocation("host0")) }: _*)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala
index f1392e9db6bfd..44a7d60929259 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala
@@ -25,6 +25,7 @@ class TaskSetBlacklistSuite extends SparkFunSuite {
   test("Blacklisting tasks, executors, and nodes") {
     val conf = new SparkConf().setAppName("test").setMaster("local")
       .set(config.BLACKLIST_ENABLED.key, "true")
+      .set(config.BLACKLIST_DECOMMISSIONING_ENABLED.key, "false")
     val clock = new ManualClock
 
     val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, clock = clock)
@@ -146,6 +147,7 @@ class TaskSetBlacklistSuite extends SparkFunSuite {
     // lead to any node blacklisting
     val conf = new SparkConf().setAppName("test").setMaster("local")
       .set(config.BLACKLIST_ENABLED.key, "true")
+      .set(config.BLACKLIST_DECOMMISSIONING_ENABLED.key, "false")
     val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, new SystemClock())
     taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0)
     taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 1)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 6f1663b210969..425db11598923 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -417,102 +417,108 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     }
   }
 
-  test("executors should be blacklisted after task failure, in spite of locality preferences") {
-    val rescheduleDelay = 300L
-    val conf = new SparkConf().
-      set(config.BLACKLIST_ENABLED, true).
-      set(config.BLACKLIST_TIMEOUT_CONF, rescheduleDelay).
-      // don't wait to jump locality levels in this test
-      set("spark.locality.wait", "0")
-
-    sc = new SparkContext("local", "test", conf)
-    // two executors on same host, one on different.
-    sched = new FakeTaskScheduler(sc, ("exec1", "host1"),
-      ("exec1.1", "host1"), ("exec2", "host2"))
-    // affinity to exec1 on host1 - which we will fail.
-    val taskSet = FakeTask.createTaskSet(1, Seq(TaskLocation("host1", "exec1")))
-    val clock = new ManualClock
-    clock.advance(1)
-    // We don't directly use the application blacklist, but its presence triggers blacklisting
-    // within the taskset.
-    val mockListenerBus = mock(classOf[LiveListenerBus])
-    val blacklistTrackerOpt = Some(new BlacklistTracker(mockListenerBus, conf, None, clock))
-    val manager = new TaskSetManager(sched, taskSet, 4, blacklistTrackerOpt, clock)
-
-    {
-      val offerResult = manager.resourceOffer("exec1", "host1", PROCESS_LOCAL)
-      assert(offerResult.isDefined, "Expect resource offer to return a task")
+  List(true, false).foreach { decommissioningBlacklistingEnabled =>
+    val blacklistStatusMsgDict = Map(true -> "enabled", false -> "disabled")
+    test("executors should be blacklisted after task failure, in spite of locality preferences " +
+      s"and decommissioning blacklisting " +
+      s"being ${blacklistStatusMsgDict(decommissioningBlacklistingEnabled)}") {
+      val rescheduleDelay = 300L
+      val conf = new SparkConf().
+        set(config.BLACKLIST_ENABLED, true).
+        set(config.BLACKLIST_DECOMMISSIONING_ENABLED, decommissioningBlacklistingEnabled).
+        set(config.BLACKLIST_TIMEOUT_CONF, rescheduleDelay).
+        // don't wait to jump locality levels in this test
+        set("spark.locality.wait", "0")
+
+      sc = new SparkContext("local", "test", conf)
+      // two executors on same host, one on different.
+      sched = new FakeTaskScheduler(sc, ("exec1", "host1"),
+        ("exec1.1", "host1"), ("exec2", "host2"))
+      // affinity to exec1 on host1 - which we will fail.
+      val taskSet = FakeTask.createTaskSet(1, Seq(TaskLocation("host1", "exec1")))
+      val clock = new ManualClock
+      clock.advance(1)
+      // We don't directly use the application blacklist, but its presence triggers blacklisting
+      // within the taskset.
+      val mockListenerBus = mock(classOf[LiveListenerBus])
+      val blacklistTrackerOpt = Some(new BlacklistTracker(mockListenerBus, conf, None, clock))
+      val manager = new TaskSetManager(sched, taskSet, 4, blacklistTrackerOpt, clock)
+
+      {
+        val offerResult = manager.resourceOffer("exec1", "host1", PROCESS_LOCAL)
+        assert(offerResult.isDefined, "Expect resource offer to return a task")
+
+        assert(offerResult.get.index === 0)
+        assert(offerResult.get.executorId === "exec1")
+
+        // Cause exec1 to fail : failure 1
+        manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost)
+        assert(!sched.taskSetsFailed.contains(taskSet.id))
 
-      assert(offerResult.get.index === 0)
-      assert(offerResult.get.executorId === "exec1")
+        // Ensure scheduling on exec1 fails after failure 1 due to blacklist
+        assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL).isEmpty)
+        assert(manager.resourceOffer("exec1", "host1", NODE_LOCAL).isEmpty)
+        assert(manager.resourceOffer("exec1", "host1", RACK_LOCAL).isEmpty)
+        assert(manager.resourceOffer("exec1", "host1", ANY).isEmpty)
+      }
 
-      // Cause exec1 to fail : failure 1
-      manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost)
-      assert(!sched.taskSetsFailed.contains(taskSet.id))
+      // Run the task on exec1.1 - should work, and then fail it on exec1.1
+      {
+        val offerResult = manager.resourceOffer("exec1.1", "host1", NODE_LOCAL)
+        assert(offerResult.isDefined,
+          "Expect resource offer to return a task for exec1.1, offerResult = " + offerResult)
 
-      // Ensure scheduling on exec1 fails after failure 1 due to blacklist
-      assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL).isEmpty)
-      assert(manager.resourceOffer("exec1", "host1", NODE_LOCAL).isEmpty)
-      assert(manager.resourceOffer("exec1", "host1", RACK_LOCAL).isEmpty)
-      assert(manager.resourceOffer("exec1", "host1", ANY).isEmpty)
-    }
+        assert(offerResult.get.index === 0)
+        assert(offerResult.get.executorId === "exec1.1")
 
-    // Run the task on exec1.1 - should work, and then fail it on exec1.1
-    {
-      val offerResult = manager.resourceOffer("exec1.1", "host1", NODE_LOCAL)
-      assert(offerResult.isDefined,
-        "Expect resource offer to return a task for exec1.1, offerResult = " + offerResult)
+        // Cause exec1.1 to fail : failure 2
+        manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost)
+        assert(!sched.taskSetsFailed.contains(taskSet.id))
 
-      assert(offerResult.get.index === 0)
-      assert(offerResult.get.executorId === "exec1.1")
+        // Ensure scheduling on exec1.1 fails after failure 2 due to blacklist
+        assert(manager.resourceOffer("exec1.1", "host1", NODE_LOCAL).isEmpty)
+      }
 
-      // Cause exec1.1 to fail : failure 2
-      manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost)
-      assert(!sched.taskSetsFailed.contains(taskSet.id))
+      // Run the task on exec2 - should work, and then fail it on exec2
+      {
+        val offerResult = manager.resourceOffer("exec2", "host2", ANY)
+        assert(offerResult.isDefined, "Expect resource offer to return a task")
 
-      // Ensure scheduling on exec1.1 fails after failure 2 due to blacklist
-      assert(manager.resourceOffer("exec1.1", "host1", NODE_LOCAL).isEmpty)
-    }
+        assert(offerResult.get.index === 0)
+        assert(offerResult.get.executorId === "exec2")
 
-    // Run the task on exec2 - should work, and then fail it on exec2
-    {
-      val offerResult = manager.resourceOffer("exec2", "host2", ANY)
-      assert(offerResult.isDefined, "Expect resource offer to return a task")
+        // Cause exec2 to fail : failure 3
+        manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost)
+        assert(!sched.taskSetsFailed.contains(taskSet.id))
 
-      assert(offerResult.get.index === 0)
-      assert(offerResult.get.executorId === "exec2")
+        // Ensure scheduling on exec2 fails after failure 3 due to blacklist
+        assert(manager.resourceOffer("exec2", "host2", ANY).isEmpty)
+      }
 
-      // Cause exec2 to fail : failure 3
-      manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost)
-      assert(!sched.taskSetsFailed.contains(taskSet.id))
+      // Despite advancing beyond the time for expiring executors from within the blacklist,
+      // we *never* expire from *within* the stage blacklist
+      clock.advance(rescheduleDelay)
 
-      // Ensure scheduling on exec2 fails after failure 3 due to blacklist
-      assert(manager.resourceOffer("exec2", "host2", ANY).isEmpty)
-    }
+      {
+        val offerResult = manager.resourceOffer("exec1", "host1", PROCESS_LOCAL)
+        assert(offerResult.isEmpty)
+      }
 
-    // Despite advancing beyond the time for expiring executors from within the blacklist,
-    // we *never* expire from *within* the stage blacklist
-    clock.advance(rescheduleDelay)
+      {
+        val offerResult = manager.resourceOffer("exec3", "host3", ANY)
+        assert(offerResult.isDefined)
+        assert(offerResult.get.index === 0)
+        assert(offerResult.get.executorId === "exec3")
 
-    {
-      val offerResult = manager.resourceOffer("exec1", "host1", PROCESS_LOCAL)
-      assert(offerResult.isEmpty)
-    }
+        assert(manager.resourceOffer("exec3", "host3", ANY).isEmpty)
 
-    {
-      val offerResult = manager.resourceOffer("exec3", "host3", ANY)
-      assert(offerResult.isDefined)
-      assert(offerResult.get.index === 0)
-      assert(offerResult.get.executorId === "exec3")
-
-      assert(manager.resourceOffer("exec3", "host3", ANY).isEmpty)
+        // Cause exec3 to fail : failure 4
+        manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost)
+      }
 
-      // Cause exec3 to fail : failure 4
-      manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost)
+      // we have failed the same task 4 times now : task id should now be in taskSetsFailed
+      assert(sched.taskSetsFailed.contains(taskSet.id))
     }
-
-    // we have failed the same task 4 times now : task id should now be in taskSetsFailed
-    assert(sched.taskSetsFailed.contains(taskSet.id))
   }
 
   test("new executors get added and lost") {
@@ -1100,44 +1106,85 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     assert(manager3.name === "TaskSet_1.1")
   }
 
-  test("don't update blacklist for shuffle-fetch failures, preemption, denied commits, " +
-      "or killed tasks") {
-    // Setup a taskset, and fail some tasks for a fetch failure, preemption, denied commit,
-    // and killed task.
+  List(true, false).foreach { decommissioningBlacklistingEnabled =>
+    val blacklistStatusMsgDict = Map(true -> "enabled", false -> "disabled")
+    test("don't update blacklist for shuffle-fetch failures, preemption, denied commits, " +
+      "or killed tasks, in spite of decommissioning blacklisting " +
+      s"being ${blacklistStatusMsgDict(decommissioningBlacklistingEnabled)}") {
+      // Setup a taskset, and fail some tasks for a fetch failure, preemption, denied commit,
+      // and killed task.
+      val conf = new SparkConf().
+        set(config.BLACKLIST_ENABLED, true).
+        set(config.BLACKLIST_DECOMMISSIONING_ENABLED, decommissioningBlacklistingEnabled)
+      sc = new SparkContext("local", "test", conf)
+      sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
+      val taskSet = FakeTask.createTaskSet(4)
+      val tsm = new TaskSetManager(sched, taskSet, 4)
+      // we need a spy so we can attach our mock blacklist
+      val tsmSpy = spy(tsm)
+      val blacklist = mock(classOf[TaskSetBlacklist])
+      when(tsmSpy.taskSetBlacklistHelperOpt).thenReturn(Some(blacklist))
+
+      // make some offers to our taskset, to get tasks we will fail
+      val taskDescs = Seq(
+        "exec1" -> "host1",
+        "exec2" -> "host1"
+      ).flatMap { case (exec, host) =>
+        // offer each executor twice (simulating 2 cores per executor)
+        (0 until 2).flatMap{ _ => tsmSpy.resourceOffer(exec, host, TaskLocality.ANY)}
+      }
+      assert(taskDescs.size === 4)
+
+      // now fail those tasks
+      tsmSpy.handleFailedTask(taskDescs(0).taskId, TaskState.FAILED,
+        FetchFailed(BlockManagerId(taskDescs(0).executorId, "host1", 12345), 0, 0, 0, "ignored"))
+      tsmSpy.handleFailedTask(taskDescs(1).taskId, TaskState.FAILED,
+        ExecutorLostFailure(taskDescs(1).executorId, exitCausedByApp = false, reason = None))
+      tsmSpy.handleFailedTask(taskDescs(2).taskId, TaskState.FAILED,
+        TaskCommitDenied(0, 2, 0))
+      tsmSpy.handleFailedTask(taskDescs(3).taskId, TaskState.KILLED, TaskKilled("test"))
+
+      // Make sure that the blacklist ignored all of the task failures above, since they aren't
+      // the fault of the executor where the task was running.
+      verify(blacklist, never())
+        .updateBlacklistForFailedTask(anyString(), anyString(), anyInt())
+    }
+  }
+
+  test("don't update blacklist for successful task sets when task execution blacklisting is " +
+    "disabled, in spite of having decommissioning blacklisting enabled") {
     val conf = new SparkConf().
-      set(config.BLACKLIST_ENABLED, true)
+      set(config.BLACKLIST_ENABLED, false).
+      set(config.BLACKLIST_DECOMMISSIONING_ENABLED, true)
+
     sc = new SparkContext("local", "test", conf)
     sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
-    val taskSet = FakeTask.createTaskSet(4)
-    val tsm = new TaskSetManager(sched, taskSet, 4)
-    // we need a spy so we can attach our mock blacklist
-    val tsmSpy = spy(tsm)
-    val blacklist = mock(classOf[TaskSetBlacklist])
-    when(tsmSpy.taskSetBlacklistHelperOpt).thenReturn(Some(blacklist))
+    val taskSet = FakeTask.createTaskSet(1)
+    val clock = new ManualClock
+    clock.advance(1)
+    val mockListenerBus = mock(classOf[LiveListenerBus])
+    // to simulate BLACKLIST_DECOMMISSIONING_ENABLED=true
+    val blacklistTrackerOpt = Some(spy(new BlacklistTracker(mockListenerBus, conf, None, clock)))
+    val tsm = new TaskSetManager(sched, taskSet, 4, blacklistTrackerOpt, clock)
 
-    // make some offers to our taskset, to get tasks we will fail
+    assert(tsm.taskSetBlacklistHelperOpt.isEmpty)
+    // make some offers to our taskset
     val taskDescs = Seq(
       "exec1" -> "host1",
       "exec2" -> "host1"
     ).flatMap { case (exec, host) =>
       // offer each executor twice (simulating 2 cores per executor)
-      (0 until 2).flatMap{ _ => tsmSpy.resourceOffer(exec, host, TaskLocality.ANY)}
+      (0 until 2).flatMap{ _ => tsm.resourceOffer(exec, host, TaskLocality.ANY)}
     }
-    assert(taskDescs.size === 4)
 
-    // now fail those tasks
-    tsmSpy.handleFailedTask(taskDescs(0).taskId, TaskState.FAILED,
-      FetchFailed(BlockManagerId(taskDescs(0).executorId, "host1", 12345), 0, 0, 0, "ignored"))
-    tsmSpy.handleFailedTask(taskDescs(1).taskId, TaskState.FAILED,
-      ExecutorLostFailure(taskDescs(1).executorId, exitCausedByApp = false, reason = None))
-    tsmSpy.handleFailedTask(taskDescs(2).taskId, TaskState.FAILED,
-      TaskCommitDenied(0, 2, 0))
-    tsmSpy.handleFailedTask(taskDescs(3).taskId, TaskState.KILLED, TaskKilled("test"))
-
-    // Make sure that the blacklist ignored all of the task failures above, since they aren't
-    // the fault of the executor where the task was running.
-    verify(blacklist, never())
-      .updateBlacklistForFailedTask(anyString(), anyString(), anyInt())
+    val directTaskResult = new DirectTaskResult[String](null, Seq()) {
+      override def value(resultSer: SerializerInstance): String = ""
+    }
+    tsm.handleSuccessfulTask(taskDescs(0).taskId, directTaskResult)
+    tsm.abort("test")
+
+    verify(blacklistTrackerOpt.get, never())
+      .updateBlacklistForSuccessfulTaskSet(anyInt(), anyInt(), any())
   }
 
   test("update application blacklist for shuffle-fetch") {
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala
index 46aa9c37986cc..f9be45426218c 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala
@@ -32,6 +32,7 @@ class KryoSerializerDistributedSuite extends SparkFunSuite with LocalSparkContex
       .set("spark.kryo.registrator", classOf[AppJarRegistrator].getName)
       .set(config.MAX_TASK_FAILURES, 1)
       .set(config.BLACKLIST_ENABLED, false)
+      .set(config.BLACKLIST_DECOMMISSIONING_ENABLED, false)
 
     val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName))
     conf.setJars(List(jar.getPath))
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index a1a858765a7d4..908d9325d1575 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -21,6 +21,7 @@ import java.util.Properties
 
 import scala.collection.JavaConverters._
 import scala.collection.Map
+import scala.collection.immutable.SortedSet
 
 import org.json4s.JsonAST.{JArray, JInt, JString, JValue}
 import org.json4s.JsonDSL._
@@ -85,9 +86,10 @@ class JsonProtocolSuite extends SparkFunSuite {
     val executorBlacklisted = SparkListenerExecutorBlacklisted(executorBlacklistedTime, "exec1", 22)
     val executorUnblacklisted =
       SparkListenerExecutorUnblacklisted(executorUnblacklistedTime, "exec1")
-    val nodeBlacklisted = SparkListenerNodeBlacklisted(nodeBlacklistedTime, "node1", 33)
+    val nodeBlacklisted = SparkListenerNodeBlacklisted(nodeBlacklistedTime, "host1",
+      ExecutorFailures(SortedSet("exec1", "exec2", "exec3")))
     val nodeUnblacklisted =
-      SparkListenerNodeUnblacklisted(nodeUnblacklistedTime, "node1")
+      SparkListenerNodeUnblacklisted(nodeUnblacklistedTime, "host1", BlacklistTimedOut)
     val executorMetricsUpdate = {
       // Use custom accum ID for determinism
       val accumUpdates =
@@ -169,6 +171,14 @@ class JsonProtocolSuite extends SparkFunSuite {
     testTaskEndReason(ExecutorLostFailure("100", true, Some("Induced failure")))
     testTaskEndReason(UnknownReason)
 
+    // NodeBlacklistReason
+    testNodeBlacklistReason(ExecutorFailures(SortedSet("exec1", "exec2", "exec3")))
+    testNodeBlacklistReason(NodeDecommissioning)
+
+    // NodeUnblacklistReason
+    testNodeUnblacklistReason(BlacklistTimedOut)
+    testNodeUnblacklistReason(NodeRunning)
+
     // BlockId
     testBlockId(RDDBlockId(1, 2))
     testBlockId(ShuffleBlockId(1, 2, 3))
@@ -494,6 +504,18 @@ private[spark] object JsonProtocolSuite extends Assertions {
     assertEquals(reason, newReason)
   }
 
+  private def testNodeBlacklistReason(reason: NodeBlacklistReason) {
+    val newReason = JsonProtocol.nodeBlacklistReasonFromJson(
+      JsonProtocol.nodeBlacklistReasonToJson(reason))
+    assertEquals(reason, newReason)
+  }
+
+  private def testNodeUnblacklistReason(reason: NodeUnblacklistReason) {
+    val newReason = JsonProtocol.nodeUnblacklistReasonFromJson(
+      JsonProtocol.nodeUnblacklistReasonToJson(reason))
+    assertEquals(reason, newReason)
+  }
+
   private def testBlockId(blockId: BlockId) {
     val newBlockId = BlockId(blockId.toString)
     assert(blockId === newBlockId)
@@ -548,6 +570,14 @@ private[spark] object JsonProtocolSuite extends Assertions {
         assertEquals(e1.executorInfo, e2.executorInfo)
       case (e1: SparkListenerExecutorRemoved, e2: SparkListenerExecutorRemoved) =>
         assert(e1.executorId === e1.executorId)
+      case (e1: SparkListenerNodeBlacklisted, e2: SparkListenerNodeBlacklisted) =>
+        assert(e1.hostId === e2.hostId)
+        assert(e1.time === e2.time)
+        assertEquals(e1.reason, e2.reason)
+      case (e1: SparkListenerNodeUnblacklisted, e2: SparkListenerNodeUnblacklisted) =>
+        assert(e1.hostId === e2.hostId)
+        assert(e1.time === e2.time)
+        assert(e1.reason === e2.reason)
       case (e1: SparkListenerExecutorMetricsUpdate, e2: SparkListenerExecutorMetricsUpdate) =>
         assert(e1.execId === e2.execId)
         assertSeqEquals[(Long, Int, Int, Seq[AccumulableInfo])](
@@ -693,6 +723,25 @@ private[spark] object JsonProtocolSuite extends Assertions {
     }
   }
 
+  private def assertEquals(reason1: NodeBlacklistReason, reason2: NodeBlacklistReason) {
+    (reason1, reason2) match {
+      case (NodeDecommissioning, NodeDecommissioning) =>
+      case (ExecutorFailures(blacklistedExecutors1), ExecutorFailures(blacklistedExecutors2)) =>
+        assert(blacklistedExecutors1 === blacklistedExecutors2)
+      case (FetchFailure(host1), FetchFailure(host2)) =>
+        assert(host1 === host2)
+      case _ => fail("Node blacklist reasons don't match in types!")
+    }
+  }
+
+  private def assertEquals(reason1: NodeUnblacklistReason, reason2: NodeUnblacklistReason) {
+    (reason1, reason2) match {
+      case (NodeRunning, NodeRunning) =>
+      case (BlacklistTimedOut, BlacklistTimedOut) =>
+      case _ => fail("Node unblacklist reasons don't match in types!")
+    }
+  }
+
   private def assertEquals(
       details1: Map[String, Seq[(String, String)]],
       details2: Map[String, Seq[(String, String)]]) {
@@ -2027,18 +2076,24 @@ private[spark] object JsonProtocolSuite extends Assertions {
   private val nodeBlacklistedJsonString =
     s"""
       |{
-      |  "Event" : "org.apache.spark.scheduler.SparkListenerNodeBlacklisted",
+      |  "Event" : "SparkListenerNodeBlacklisted",
+      |  "hostId" : "host1",
       |  "time" : ${nodeBlacklistedTime},
-      |  "hostId" : "node1",
-      |  "executorFailures" : 33
+      |  "blacklistReason" : {
+      |    "reason" : "ExecutorFailures",
+      |    "blacklistedExecutors" : [ "exec1", "exec2", "exec3" ]
+      |  }
       |}
     """.stripMargin
   private val nodeUnblacklistedJsonString =
     s"""
       |{
-      |  "Event" : "org.apache.spark.scheduler.SparkListenerNodeUnblacklisted",
+      |  "Event" : "SparkListenerNodeUnblacklisted",
+      |  "hostId" : "host1",
       |  "time" : ${nodeUnblacklistedTime},
-      |  "hostId" : "node1"
+      |  "unblacklistReason" : {
+      |    "reason" : "BlacklistTimedOut"
+      |  }
       |}
     """.stripMargin
 }
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
index 7052fb347106b..620f08d5cdef4 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -32,15 +32,15 @@ import org.apache.hadoop.yarn.client.api.AMRMClient
 import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
 import org.apache.hadoop.yarn.conf.YarnConfiguration
 
-import org.apache.spark.{SecurityManager, SparkConf, SparkException}
+import org.apache.spark.{HostState, SecurityManager, SparkConf, SparkException}
 import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
 import org.apache.spark.deploy.yarn.config._
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
 import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef}
 import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason}
-import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor
-import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RetrieveLastAllocatedExecutorId
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{HostStatusUpdate,
+  RemoveExecutor, RetrieveLastAllocatedExecutorId}
 import org.apache.spark.util.{Clock, SystemClock, ThreadUtils}
 
 /**
@@ -266,6 +266,23 @@ private[yarn] class YarnAllocator(
     // requests.
     val allocateResponse = amClient.allocate(progressIndicator)
 
+    val updatedNodeReports = allocateResponse.getUpdatedNodes
+
+    updatedNodeReports.asScala.foreach(nodeReport => {
+      logInfo("Yarn node state updated for host %s to %s"
+        .format(nodeReport.getNodeId.getHost, nodeReport.getNodeState.name))
+
+      val hostState = HostState.fromYarnState(nodeReport.getNodeState.name)
+      hostState match {
+        case Some(state) =>
+          driverRef.send(HostStatusUpdate(nodeReport.getNodeId.getHost, state))
+
+        case None =>
+          logWarning("Cannot find Host state corresponding to YARN node state %s"
+            .format(nodeReport.getNodeState.name))
+      }
+    })
+
     val allocatedContainers = allocateResponse.getAllocatedContainers()
 
     if (allocatedContainers.size > 0) {
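As a side note, here is a minimal usage sketch (not part of the diff) of the HostState translation helpers the allocator relies on in the hunk above. The object name HostStateTranslationSketch is hypothetical, and the expected results are assumptions inferred from how this hunk and YarnAllocatorSuite below use the helpers; the sketch has to live inside the org.apache.spark package because HostState is package-private.

    package org.apache.spark

    // Hedged sketch: YARN reports upper-cased state names via NodeReport.getNodeState.name,
    // which is what fromYarnState is fed in the allocator code above.
    object HostStateTranslationSketch {
      def main(args: Array[String]): Unit = {
        // a known YARN state is expected to map to the corresponding HostState value
        assert(HostState.fromYarnState("DECOMMISSIONING").contains(HostState.Decommissioning))
        // an unrecognized state maps to None; the allocator logs a warning and skips it
        assert(HostState.fromYarnState("SOME_FUTURE_STATE").isEmpty)
        // the reverse translation yields the YARN NodeState name
        assert(HostState.toYarnState(HostState.Running).contains("RUNNING"))
      }
    }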
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index 415a29fd887e8..6acdb3d3aad38 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -264,6 +264,10 @@ private[spark] abstract class YarnSchedulerBackend(
       case AddWebUIFilter(filterName, filterParams, proxyBase) =>
         addWebUIFilter(filterName, filterParams, proxyBase)
 
+      case HostStatusUpdate(host, hostState) =>
+        logDebug("Received updated state %s for host %s".format(hostState, host))
+        handleUpdatedHostState(host, hostState)
+
       case r @ RemoveExecutor(executorId, reason) =>
         logWarning(reason.toString)
         driverEndpoint.ask[Boolean](r).onFailure {
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
index 9c3b18e4ec5f3..972849c9b8a7e 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
@@ -28,10 +28,12 @@ import scala.language.postfixOps
 
 import com.google.common.io.Files
 import org.apache.commons.lang3.SerializationUtils
+import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.yarn.conf.YarnConfiguration
 import org.apache.hadoop.yarn.server.MiniYARNCluster
 import org.scalatest.{BeforeAndAfterAll, Matchers}
 import org.scalatest.concurrent.Eventually._
+import org.scalatest.time
 
 import org.apache.spark._
 import org.apache.spark.deploy.yarn.config._
@@ -78,6 +80,28 @@ abstract class BaseYarnClusterSuite
     val logConfFile = new File(logConfDir, "log4j.properties")
     Files.write(LOG4J_CONF, logConfFile, StandardCharsets.UTF_8)
 
+    restartCluster()
+
+    fakeSparkJar = File.createTempFile("sparkJar", null, tempDir)
+    hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR)
+    assert(hadoopConfDir.mkdir())
+    File.createTempFile("token", ".txt", hadoopConfDir)
+  }
+
+  override def afterAll() {
+    try {
+      yarnCluster.stop()
+    } finally {
+      System.setProperties(oldSystemProperties)
+      super.afterAll()
+    }
+  }
+
+  protected def restartCluster(): Unit = {
+    if (yarnCluster != null) {
+      yarnCluster.stop()
+    }
+
     // Disable the disk utilization check to avoid the test hanging when people's disks are
     // getting full.
     val yarnConf = newYarnConfig()
@@ -113,20 +137,6 @@ abstract class BaseYarnClusterSuite
     }
 
     logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}")
-
-    fakeSparkJar = File.createTempFile("sparkJar", null, tempDir)
-    hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR)
-    assert(hadoopConfDir.mkdir())
-    File.createTempFile("token", ".txt", hadoopConfDir)
-  }
-
-  override def afterAll() {
-    try {
-      yarnCluster.stop()
-    } finally {
-      System.setProperties(oldSystemProperties)
-      super.afterAll()
-    }
   }
 
   protected def runSpark(
@@ -137,7 +147,9 @@ abstract class BaseYarnClusterSuite
       extraClassPath: Seq[String] = Nil,
       extraJars: Seq[String] = Nil,
       extraConf: Map[String, String] = Map(),
-      extraEnv: Map[String, String] = Map()): SparkAppHandle.State = {
+      extraEnv: Map[String, String] = Map(),
+      numExecutors: Int = 1,
+      executionTimeout: time.Span = 2 minutes): SparkAppHandle.State = {
     val deployMode = if (clientMode) "client" else "cluster"
     val propsFile = createConfFile(extraClassPath = extraClassPath, extraConf = extraConf)
     val env = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()) ++ extraEnv
@@ -152,7 +164,7 @@ abstract class BaseYarnClusterSuite
     launcher.setSparkHome(sys.props("spark.test.home"))
       .setMaster("yarn")
       .setDeployMode(deployMode)
-      .setConf("spark.executor.instances", "1")
+      .setConf("spark.executor.instances", numExecutors.toString)
       .setPropertiesFile(propsFile)
       .addAppArgs(appArgs.toArray: _*)
 
@@ -167,7 +179,7 @@ abstract class BaseYarnClusterSuite
 
     val handle = launcher.startApplication()
     try {
-      eventually(timeout(2 minutes), interval(1 second)) {
+      eventually(timeout(executionTimeout), interval(1 second)) {
         assert(handle.getState().isFinal())
       }
     } finally {
@@ -238,4 +250,30 @@ abstract class BaseYarnClusterSuite
     propsFile.getAbsolutePath()
   }
 
+  protected def getClusterWorkDir: File = yarnCluster.getTestWorkDir
+
+  /**
+   * Gracefully decommissions the only node in the mini YARN cluster, if that
+   * functionality is available in the Hadoop version the cluster is built against.
+   * Throws an exception if the decommissioning functionality is not available.
+   * @return the host that will be decommissioned.
+   */
+  protected def gracefullyDecommissionNode(conf: Configuration,
+                                           excludedHostsFile: File,
+                                           decommissionTimeout: time.Span): String = {
+    val resourceManager = yarnCluster.getResourceManager
+    val nodesListManager = resourceManager.getRMContext.getNodesListManager
+    val clusterNodes = Map(resourceManager.getRMContext.getRMNodes.asScala.toSeq : _ *)
+    assert(!clusterNodes.isEmpty)
+    // note the MiniYARNCluster will always have a single node
+    val hostToExclude = clusterNodes.keysIterator.next().getHost
+    Files.append(s"$hostToExclude ${decommissionTimeout.toSeconds}${sys.props("line.separator")}",
+                 excludedHostsFile, StandardCharsets.UTF_8)
+    // use reflection so this compiles for other YARN versions, but fails with a
+    // reflection exception if executed with incompatible versions of YARN
+    nodesListManager.getClass
+      .getMethod("refreshNodes", classOf[Configuration], java.lang.Boolean.TYPE)
+      .invoke(nodesListManager, conf, true.asInstanceOf[java.lang.Object])
+    hostToExclude
+  }
 }
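The reflection in gracefullyDecommissionNode above exists only so the suite still compiles against Hadoop versions without graceful decommission support. For illustration, and assuming a Hadoop build whose NodesListManager actually exposes the refreshNodes(Configuration, boolean) overload that the reflection probes for (the object and helper names below are hypothetical), the direct equivalent might look like:

    import java.io.File
    import java.nio.charset.StandardCharsets

    import com.google.common.io.Files
    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.yarn.server.resourcemanager.NodesListManager

    object GracefulDecommissionSketch {
      // Appends "<host> <timeout in seconds>" to the excludes file (the same line format the
      // suite writes) and asks the RM to re-read it with graceful decommission enabled.
      def decommissionDirectly(
          nodesListManager: NodesListManager,
          conf: Configuration,
          excludedHostsFile: File,
          host: String,
          timeoutSeconds: Long): Unit = {
        Files.append(s"$host $timeoutSeconds${sys.props("line.separator")}",
          excludedHostsFile, StandardCharsets.UTF_8)
        // assumed overload; only present on Hadoop versions that ship graceful decommission
        nodesListManager.refreshNodes(conf, true)
      }
    }

Keeping the reflective version in the suite trades a compile-time check for compatibility with older YARN versions, which is why a reflection failure is treated as "functionality not available".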
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
index cb1e3c5268510..3e7ae2ae95c50 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
@@ -20,19 +20,22 @@ package org.apache.spark.deploy.yarn
 import scala.collection.JavaConverters._
 
 import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.yarn.api.protocolrecords.AllocateResponse
 import org.apache.hadoop.yarn.api.records._
 import org.apache.hadoop.yarn.client.api.AMRMClient
 import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
 import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.mockito.{Matchers => MockitoMatchers}
 import org.mockito.Mockito._
 import org.scalatest.{BeforeAndAfterEach, Matchers}
 
-import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
+import org.apache.spark.{HostState, SecurityManager, SparkConf, SparkFunSuite}
 import org.apache.spark.deploy.yarn.YarnAllocator._
 import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
 import org.apache.spark.deploy.yarn.config._
 import org.apache.spark.rpc.RpcEndpointRef
 import org.apache.spark.scheduler.SplitInfo
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.HostStatusUpdate
 import org.apache.spark.util.ManualClock
 
 class MockResolver extends SparkRackResolver {
@@ -83,7 +86,8 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter
 
   def createAllocator(
       maxExecutors: Int = 5,
-      rmClient: AMRMClient[ContainerRequest] = rmClient): YarnAllocator = {
+      rmClient: AMRMClient[ContainerRequest] = rmClient,
+      driverRef: RpcEndpointRef = mock(classOf[RpcEndpointRef])): YarnAllocator = {
     val args = Array(
       "--jar", "somejar.jar",
       "--class", "SomeClass")
@@ -94,7 +98,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter
       .set("spark.executor.memory", "2048")
     new YarnAllocator(
       "not used",
-      mock(classOf[RpcEndpointRef]),
+      driverRef,
       conf,
       sparkConfClone,
       rmClient,
@@ -350,4 +354,37 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter
     clock.advance(50 * 1000L)
     handler.getNumExecutorsFailed should be (0)
   }
+
+  test("HostStatusUpdate signal on YARN node state change") {
+    val mockAmClient = mock(classOf[AMRMClient[ContainerRequest]])
+    val mockAllocateResponse = mock(classOf[AllocateResponse])
+    val mockNodeReport1 = mock(classOf[NodeReport])
+    val mockNodeReport2 = mock(classOf[NodeReport])
+    val mockNodeId1 = mock(classOf[NodeId])
+    val mockNodeId2 = mock(classOf[NodeId])
+
+    val nodeState1 = HostState.toYarnState(HostState.Decommissioning)
+    assert(nodeState1.isDefined)
+    val nodeState2 = HostState.toYarnState(HostState.Running)
+    assert(nodeState2.isDefined)
+
+    when(mockNodeId1.getHost).thenReturn("host1")
+    when(mockNodeId2.getHost).thenReturn("host2")
+    when(mockNodeReport1.getNodeState).thenReturn(NodeState.valueOf(nodeState1.get))
+    when(mockNodeReport2.getNodeState).thenReturn(NodeState.valueOf(nodeState2.get))
+    when(mockNodeReport1.getNodeId).thenReturn(mockNodeId1)
+    when(mockNodeReport2.getNodeId).thenReturn(mockNodeId2)
+
+    when(mockAllocateResponse.getUpdatedNodes).thenReturn(List(mockNodeReport1,
+      mockNodeReport2).asJava)
+    when(mockAmClient.allocate(MockitoMatchers.anyFloat())).thenReturn(mockAllocateResponse)
+
+    val driverRef = mock(classOf[RpcEndpointRef])
+    val handler = createAllocator(4, mockAmClient, driverRef)
+
+    handler.allocateResources()
+
+    verify(driverRef).send(HostStatusUpdate("host1", HostState.Decommissioning))
+    verify(driverRef).send(HostStatusUpdate("host2", HostState.Running))
+  }
 }
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnDecommissioningSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnDecommissioningSuite.scala
new file mode 100644
index 0000000000000..b714f4a0e0bbd
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnDecommissioningSuite.scala
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.io.File
+import java.nio.charset.StandardCharsets
+import java.nio.file.{Files => JFiles, Paths}
+import java.util.concurrent.{Executors, TimeUnit}
+
+import scala.collection.JavaConverters._
+import scala.concurrent.{ExecutionContext, Future, Promise}
+import scala.concurrent.duration._
+import scala.io.Source
+import scala.language.postfixOps
+import scala.util.Try
+
+import com.google.common.io.Files
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.scalatest.{BeforeAndAfter, Matchers}
+import org.scalatest.exceptions.TestFailedDueToTimeoutException
+
+import org.apache.spark.{HostState, SparkConf, SparkContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.tags.ExtendedYarnTest
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * Integration tests for YARN's graceful decommission mechanism. These tests run
+ * Spark-on-YARN applications on a mini YARN cluster and require the Spark assembly to be
+ * built before they can run successfully.
+ * Each test triggers the decommission of the only node in the mini YARN cluster and then
+ * checks the YARN container logs to verify that the node state transitions reached the driver.
+ */
+@ExtendedYarnTest
+class YarnDecommissioningSuite extends BaseYarnClusterSuite with BeforeAndAfter {
+
+  private val (excludedHostsFile, syncFile) = {
+    val (excludedHostsFile, syncFile) = (File.createTempFile("yarn-excludes", null, tempDir),
+                                         File.createTempFile("syncFile", null, tempDir))
+    excludedHostsFile.deleteOnExit()
+    syncFile.deleteOnExit()
+    logInfo(s"Using YARN excludes file ${excludedHostsFile.getAbsolutePath}")
+    logInfo(s"Using sync file ${syncFile.getAbsolutePath}")
+    (excludedHostsFile, syncFile)
+  }
+  // used to avoid restarting the MiniYARNCluster on the first test run
+  private var firstTestRun = true
+  private val executorService = Executors.newSingleThreadScheduledExecutor()
+  private implicit val ec = ExecutionContext.fromExecutorService(executorService)
+
+  override val newYarnConfig: YarnConfiguration = {
+    val conf = new YarnConfiguration()
+    conf.set("yarn.resourcemanager.nodes.exclude-path", excludedHostsFile.getAbsolutePath)
+    conf
+  }
+
+  private val decommissionStates = Set(HostState.Decommissioning,
+                                       HostState.Decommissioned).map{ state =>
+    val yarnStateOpt = HostState.toYarnState(state)
+    assert(yarnStateOpt.isDefined,
+           s"Spark host state $state should have a translation to YARN state")
+    yarnStateOpt.get
+  }
+
+  before {
+    if (!firstTestRun) {
+      Files.write("", excludedHostsFile, StandardCharsets.UTF_8)
+      Files.write("", syncFile, StandardCharsets.UTF_8)
+      restartCluster()
+    }
+    firstTestRun = false
+  }
+
+  test("Spark application master gets notified on node decommissioning when running in" +
+    " cluster mode") {
+    val excludedHostStateUpdates = testNodeDecommission(clientMode = false)
+    excludedHostStateUpdates shouldEqual(decommissionStates)
+  }
+
+  test("Spark application master gets notified on node decommissioning when running in" +
+    " client mode") {
+    val excludedHostStateUpdates = testNodeDecommission(clientMode = true)
+    // In client mode the node doesn't always have time to reach the decommissioned state
+    assert(excludedHostStateUpdates.subsetOf(decommissionStates))
+    assert(excludedHostStateUpdates
+             .contains(HostState.toYarnState(HostState.Decommissioning).get))
+  }
+
+  /**
+   * @return the set of YARN decommission-related states (as strings) that the only node in
+   *         the MiniYARNCluster has transitioned to after the Spark job started.
+   */
+  private def testNodeDecommission(clientMode: Boolean): Set[String] = {
+    val excludedHostPromise = Promise[String]
+    scheduleDecommissionRunnable(excludedHostPromise)
+
+    // surface exceptions in the executor service
+    val excludedHostFuture = excludedHostPromise.future
+    excludedHostFuture.onFailure { case t => throw t  }
+    // we expect a timeout exception because the job will fail when the only available node
+    // is decommissioned after its timeout
+    intercept[TestFailedDueToTimeoutException] {
+      runSpark(clientMode, mainClassName(YarnDecommissioningDriver.getClass),
+        appArgs = Seq(syncFile.getAbsolutePath),
+        extraConf = Map(),
+        numExecutors = 2,
+        executionTimeout = 2 minutes)
+    }
+    assert(excludedHostPromise.isCompleted, "graceful decommission was not launched for any node")
+    val excludedHost = ThreadUtils.awaitResult(excludedHostFuture, 1 millisecond)
+    assert(excludedHost.length > 0)
+    getExcludedHostStateUpdate(excludedHost)
+  }
+
+  /**
+   * This method repeatedly schedules a task that checks the contents of the syncFile used to
+   * synchronize with the Spark driver. Once the syncFile contains the sync text, YARN's
+   * graceful decommission mechanism is triggered and the excluded host is returned by
+   * completing excludedHostPromise.
+   */
+  private def scheduleDecommissionRunnable(excludedHostPromise: Promise[String]): Unit = {
+    def decommissionRunnable(): Runnable = new Runnable() {
+      override def run() {
+        if (syncFile.exists() &&
+          Files.toString(syncFile, StandardCharsets.UTF_8)
+            .equals(YarnDecommissioningDriver.SYNC_TEXT)) {
+          excludedHostPromise.complete(Try{
+            logInfo("Launching graceful decommission of a node in YARN")
+            gracefullyDecommissionNode(newYarnConfig, excludedHostsFile,
+              decommissionTimeout = 10 seconds)
+          })
+        } else {
+          logDebug("Waiting for sync file to be updated by the driver")
+          executorService.schedule(decommissionRunnable(), 100, TimeUnit.MILLISECONDS)
+        }
+      }
+    }
+    executorService.schedule(decommissionRunnable(), 1, TimeUnit.SECONDS)
+  }
+
+  /**
+   * This method should be called after the Spark application has completed, to parse
+   * the container logs for messages about YARN decommission-related states involving
+   * the node that was decommissioned.
+   */
+  private def getExcludedHostStateUpdate(excludedHost: String): Set[String] = {
+    val stateChangeRe = {
+      val decommissionStateRe = decommissionStates.mkString("|")
+      "(?:%s.*(%s))|(?:(%s).*%s)".format(excludedHost, decommissionStateRe,
+        decommissionStateRe, excludedHost).r
+    }
+    (for {
+      file <- JFiles.walk(Paths.get(getClusterWorkDir.getAbsolutePath))
+                    .iterator().asScala
+      if file.getFileName.toString == "stderr"
+      line <- Source.fromFile(file.toFile).getLines()
+      matchingSubgroups <- stateChangeRe.findFirstMatchIn(line)
+                                        .map(_.subgroups.filter(_ != null)).toSeq
+      group <- matchingSubgroups
+    } yield group).toSet
+  }
+}
+
+private object YarnDecommissioningDriver extends Logging with Matchers {
+
+  val SYNC_TEXT = "First action completed. Start decommissioning."
+  val WAIT_TIMEOUT_MILLIS = 10000
+  val DECOMMISSION_WAIT_TIME_MILLIS = 500000
+
+  def main(args: Array[String]): Unit = {
+    if (args.length != 1) {
+      // scalastyle:off println
+      System.err.println(
+        s"""
+           |Invalid command line: ${args.mkString(" ")}
+           |
+           |Usage: YarnDecommissioningDriver [sync file]
+        """.stripMargin)
+      // scalastyle:on println
+      System.exit(1)
+    }
+    val sc = new SparkContext(new SparkConf()
+        .setAppName("Yarn Decommissioning Test"))
+    try {
+      logInfo("Starting YarnDecommissioningDriver")
+      val counts = sc.parallelize(1 to 10, 4)
+                      .map { x => (x % 7, x) }
+                      .reduceByKey(_ + _).collect()
+      sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+      logInfo(s"Got ${counts.mkString(",")}")
+
+      val syncFile = new File(args(0))
+      Files.append(SYNC_TEXT, syncFile, StandardCharsets.UTF_8)
+      logInfo(s"Sync file ${syncFile} written")
+
+      // Wait for decommissioning and then for decommissioned, the timeout in
+      // the corresponding call to runSpark will interrupt this
+      Thread.sleep(DECOMMISSION_WAIT_TIME_MILLIS)
+    } catch {
+      case e: Throwable =>
+        logError(s"Driver exception: ${e.getMessage}")
+    } finally {
+      sc.stop()
+    }
+  }
+}
+
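The new suite and YarnDecommissioningDriver coordinate through a plain sync file rather than any Spark RPC: the driver appends SYNC_TEXT once its first action completes, and the scheduled runnable in the suite polls the file before triggering the graceful decommission. Below is a minimal sketch of that polling predicate (illustration only; the object and method names are hypothetical):

    import java.io.File
    import java.nio.charset.StandardCharsets

    import com.google.common.io.Files

    object SyncFileHandshakeSketch {
      // True once the driver has written the agreed-upon marker text to the sync file.
      def markerWritten(syncFile: File, marker: String): Boolean =
        syncFile.exists() && Files.toString(syncFile, StandardCharsets.UTF_8) == marker
    }

A file-based handshake keeps the integration test decoupled from the application's RPC endpoints, at the cost of polling.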


 
