You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ho...@apache.org on 2020/07/20 04:34:47 UTC

[spark] branch master updated: [SPARK-20629][CORE][K8S] Copy shuffle data when nodes are being shutdown

This is an automated email from the ASF dual-hosted git repository.

holden pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a4ca355  [SPARK-20629][CORE][K8S] Copy shuffle data when nodes are being shutdown
a4ca355 is described below

commit a4ca355af8556e8c5948e492ef70ef0b48416dc4
Author: Holden Karau <hk...@apple.com>
AuthorDate: Sun Jul 19 21:33:13 2020 -0700

    [SPARK-20629][CORE][K8S] Copy shuffle data when nodes are being shutdown
    
    ### What is changed?
    
    This pull request adds the ability to migrate shuffle files during Spark's decommissioning. The design document associated with this change is at https://docs.google.com/document/d/1xVO1b6KAwdUhjEJBolVPl9C6sLj7oOveErwDSYdT-pE .
    
    To allow this change, the `MapOutputTracker` has been extended to allow the location of shuffle files to be updated with `updateMapOutput`. When a shuffle block is put, a block update message will be sent which triggers the `updateMapOutput`.
    
    Instead of rejecting remote puts of shuffle blocks, `BlockManager` delegates the storage of shuffle blocks to its shuffle manager's resolver (if supported). A new experimental trait is added for shuffle resolvers to indicate they handle remote putting of blocks.
    
    The existing block migration code is moved out into a separate file, and a producer/consumer model is introduced for migrating shuffle files from the host as quickly as possible while not overwhelming other executors.
    
    ### Why are the changes needed?
    
    Recomputing shuffle blocks can be expensive, so we should take advantage of our decommissioning time to migrate these blocks.
    
    ### Does this PR introduce any user-facing change?
    
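    This PR introduces three new config parameters: `spark.storage.decommission.shuffleBlocks.enabled` & `spark.storage.decommission.rddBlocks.enabled` control which blocks should be migrated during storage decommissioning, and `spark.storage.decommission.shuffleBlocks.maxThreads` limits the number of threads used to migrate shuffle files.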
    
    ### How was this patch tested?
    
    New unit test & expansion of the Spark on K8s decom test to assert that decommissioning with shuffle block migration means that the results are not recomputed even when the original executor is terminated.
    
    This PR is a cleaned-up version of the previous WIP PR I made https://github.com/apache/spark/pull/28331 (thanks to attilapiros for his very helpful reviewing on it :)).
    
    Closes #28708 from holdenk/SPARK-20629-copy-shuffle-data-when-nodes-are-being-shutdown-cleaned-up.
    
    Lead-authored-by: Holden Karau <hk...@apple.com>
    Co-authored-by: Holden Karau <ho...@pigscanfly.ca>
    Co-authored-by: “attilapiros” <pi...@gmail.com>
    Co-authored-by: Attila Zsolt Piros <at...@apiros-mbp16.lan>
    Signed-off-by: Holden Karau <hk...@apple.com>
---
 .../scala/org/apache/spark/MapOutputTracker.scala  |  38 ++-
 .../src/main/scala/org/apache/spark/SparkEnv.scala |   3 +-
 .../org/apache/spark/internal/config/package.scala |  23 ++
 .../network/netty/NettyBlockTransferService.scala  |   5 +-
 .../org/apache/spark/scheduler/MapStatus.scala     |  15 +-
 .../cluster/StandaloneSchedulerBackend.scala       |   2 +-
 .../spark/shuffle/IndexShuffleBlockResolver.scala  |  99 ++++++-
 .../apache/spark/shuffle/MigratableResolver.scala  |  48 +++
 .../apache/spark/shuffle/ShuffleBlockInfo.scala    |  28 ++
 .../scala/org/apache/spark/storage/BlockId.scala   |   5 +-
 .../org/apache/spark/storage/BlockManager.scala    | 154 +++-------
 .../spark/storage/BlockManagerDecommissioner.scala | 330 +++++++++++++++++++++
 .../spark/storage/BlockManagerMasterEndpoint.scala |  26 +-
 ...avedOnDecommissionedBlockManagerException.scala |  21 ++
 .../spark/scheduler/WorkerDecommissionSuite.scala  |   2 +-
 .../sort/IndexShuffleBlockResolverSuite.scala      |   3 +-
 .../org/apache/spark/storage/BlockIdSuite.scala    |   4 +-
 .../BlockManagerDecommissionIntegrationSuite.scala | 229 ++++++++++++++
 .../storage/BlockManagerDecommissionSuite.scala    | 106 -------
 .../BlockManagerDecommissionUnitSuite.scala        |  92 ++++++
 .../storage/BlockManagerReplicationSuite.scala     |   2 +-
 .../apache/spark/storage/BlockManagerSuite.scala   |  93 +++++-
 .../k8s/integrationtest/DecommissionSuite.scala    |  13 +-
 .../k8s/integrationtest/KubernetesSuite.scala      |  35 ++-
 .../integration-tests/tests/decommissioning.py     |  27 +-
 .../streaming/ReceivedBlockHandlerSuite.scala      |   2 +-
 26 files changed, 1150 insertions(+), 255 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 32251df..64102cc 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -49,7 +49,7 @@ import org.apache.spark.util._
  *
  * All public methods of this class are thread-safe.
  */
-private class ShuffleStatus(numPartitions: Int) {
+private class ShuffleStatus(numPartitions: Int) extends Logging {
 
   private val (readLock, writeLock) = {
     val lock = new ReentrantReadWriteLock()
@@ -122,11 +122,27 @@ private class ShuffleStatus(numPartitions: Int) {
   }
 
   /**
+   * Update the map output location (e.g. during migration).
+   */
+  def updateMapOutput(mapId: Long, bmAddress: BlockManagerId): Unit = withWriteLock {
+    // removeMapOutput nulls out slots, so skip null entries when searching by mapId.
+    mapStatuses.find(s => s != null && s.mapId == mapId) match {
+      case Some(mapStatus) =>
+        logInfo(s"Updating map output for ${mapId} to ${bmAddress}")
+        mapStatus.updateLocation(bmAddress)
+        invalidateSerializedMapOutputStatusCache()
+      case None =>
+        logError(s"Asked to update map output ${mapId} for untracked map status.")
+    }
+  }
+
+  /**
    * Remove the map output which was served by the specified block manager.
    * This is a no-op if there is no registered map output or if the registered output is from a
    * different block manager.
    */
   def removeMapOutput(mapIndex: Int, bmAddress: BlockManagerId): Unit = withWriteLock {
+    logDebug(s"Removing existing map output ${mapIndex} ${bmAddress}")
     if (mapStatuses(mapIndex) != null && mapStatuses(mapIndex).location == bmAddress) {
       _numAvailableOutputs -= 1
       mapStatuses(mapIndex) = null
@@ -139,6 +155,7 @@ private class ShuffleStatus(numPartitions: Int) {
    * outputs which are served by an external shuffle server (if one exists).
    */
   def removeOutputsOnHost(host: String): Unit = withWriteLock {
+    logDebug(s"Removing outputs for host ${host}")
     removeOutputsByFilter(x => x.host == host)
   }
 
@@ -148,6 +165,7 @@ private class ShuffleStatus(numPartitions: Int) {
    * still registered with that execId.
    */
   def removeOutputsOnExecutor(execId: String): Unit = withWriteLock {
+    logDebug(s"Removing outputs for execId ${execId}")
     removeOutputsByFilter(x => x.executorId == execId)
   }
 
@@ -265,7 +283,7 @@ private[spark] class MapOutputTrackerMasterEndpoint(
   override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
     case GetMapOutputStatuses(shuffleId: Int) =>
       val hostPort = context.senderAddress.hostPort
-      logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
+      logInfo(s"Asked to send map output locations for shuffle ${shuffleId} to ${hostPort}")
       tracker.post(new GetMapOutputMessage(shuffleId, context))
 
     case StopMapOutputTracker =>
@@ -465,6 +483,15 @@ private[spark] class MapOutputTrackerMaster(
     }
   }
 
+  def updateMapOutput(shuffleId: Int, mapId: Long, bmAddress: BlockManagerId): Unit = {
+    shuffleStatuses.get(shuffleId) match {
+      case Some(shuffleStatus) =>
+        shuffleStatus.updateMapOutput(mapId, bmAddress)
+      case None =>
+        logError(s"Asked to update map output for unknown shuffle ${shuffleId}")
+    }
+  }
+
   def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): Unit = {
     shuffleStatuses(shuffleId).addMapOutput(mapIndex, status)
   }
@@ -745,7 +772,12 @@ private[spark] class MapOutputTrackerMaster(
   override def stop(): Unit = {
     mapOutputRequests.offer(PoisonPill)
     threadpool.shutdown()
-    sendTracker(StopMapOutputTracker)
+    try {
+      sendTracker(StopMapOutputTracker)
+    } catch {
+      case e: SparkException =>
+        logError("Could not tell tracker we are stopping.", e)
+    }
     trackerEndpoint = null
     shuffleStatuses.clear()
   }
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 8ba17398..d543359 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -367,7 +367,8 @@ object SparkEnv extends Logging {
             externalShuffleClient
           } else {
             None
-          }, blockManagerInfo)),
+          }, blockManagerInfo,
+          mapOutputTracker.asInstanceOf[MapOutputTrackerMaster])),
       registerOrLookupEndpoint(
         BlockManagerMaster.DRIVER_HEARTBEAT_ENDPOINT_NAME,
         new BlockManagerMasterHeartbeatEndpoint(rpcEnv, isLocal, blockManagerInfo)),
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index f0b292b..e1b598e 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -420,6 +420,29 @@ package object config {
       .booleanConf
       .createWithDefault(false)
 
+  private[spark] val STORAGE_DECOMMISSION_SHUFFLE_BLOCKS_ENABLED =
+    ConfigBuilder("spark.storage.decommission.shuffleBlocks.enabled")
+      .doc("Whether to transfer shuffle blocks during block manager decommissioning. Requires " +
+        "a migratable shuffle resolver (like sort based shuffle).")
+      .version("3.1.0")
+      .booleanConf
+      .createWithDefault(false)
+
+  private[spark] val STORAGE_DECOMMISSION_SHUFFLE_MAX_THREADS =
+    ConfigBuilder("spark.storage.decommission.shuffleBlocks.maxThreads")
+      .doc("Maximum number of threads to use in migrating shuffle files.")
+      .version("3.1.0")
+      .intConf
+      .checkValue(_ > 0, "The maximum number of threads should be positive")
+      .createWithDefault(8)
+
+  private[spark] val STORAGE_DECOMMISSION_RDD_BLOCKS_ENABLED =
+    ConfigBuilder("spark.storage.decommission.rddBlocks.enabled")
+      .doc("Whether to transfer RDD blocks during block manager decommissioning.")
+      .version("3.1.0")
+      .booleanConf
+      .createWithDefault(false)
+
   private[spark] val STORAGE_DECOMMISSION_MAX_REPLICATION_FAILURE_PER_BLOCK =
     ConfigBuilder("spark.storage.decommission.maxReplicationFailuresPerBlock")
       .internal()
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index 3de7377..5d9cea0 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -168,7 +168,10 @@ private[spark] class NettyBlockTransferService(
     // Everything else is encoded using our binary protocol.
     val metadata = JavaUtils.bufferToArray(serializer.newInstance().serialize((level, classTag)))
 
-    val asStream = blockData.size() > conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)
+    // We always transfer shuffle blocks as a stream for simplicity with the receiving code since
+    // they are always written to disk. Otherwise we check the block size.
+    val asStream = (blockData.size() > conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) ||
+      blockId.isShuffle)
     val callback = new RpcResponseCallback {
       override def onSuccess(response: ByteBuffer): Unit = {
         logTrace(s"Successfully uploaded block $blockId${if (asStream) " as stream" else ""}")
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index 7f8893f..0af3a2e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -30,12 +30,15 @@ import org.apache.spark.util.Utils
 
 /**
  * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the
- * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks.
+ * task has shuffle files stored on as well as the sizes of outputs for each reducer, for passing
+ * on to the reduce tasks.
  */
 private[spark] sealed trait MapStatus {
-  /** Location where this task was run. */
+  /** Location where this task output is. */
   def location: BlockManagerId
 
+  def updateLocation(newLoc: BlockManagerId): Unit
+
   /**
    * Estimated size for the reduce block, in bytes.
    *
@@ -126,6 +129,10 @@ private[spark] class CompressedMapStatus(
 
   override def location: BlockManagerId = loc
 
+  override def updateLocation(newLoc: BlockManagerId): Unit = {
+    loc = newLoc
+  }
+
   override def getSizeForBlock(reduceId: Int): Long = {
     MapStatus.decompressSize(compressedSizes(reduceId))
   }
@@ -178,6 +185,10 @@ private[spark] class HighlyCompressedMapStatus private (
 
   override def location: BlockManagerId = loc
 
+  override def updateLocation(newLoc: BlockManagerId): Unit = {
+    loc = newLoc
+  }
+
   override def getSizeForBlock(reduceId: Int): Long = {
     assert(hugeBlockSizes != null)
     if (emptyBlocks.contains(reduceId)) {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
index ec1299a..4024b44 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
@@ -44,7 +44,7 @@ private[spark] class StandaloneSchedulerBackend(
   with StandaloneAppClientListener
   with Logging {
 
-  private var client: StandaloneAppClient = null
+  private[spark] var client: StandaloneAppClient = null
   private val stopping = new AtomicBoolean(false)
   private val launcherBackend = new LauncherBackend() {
     override protected def conf: SparkConf = sc.conf
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index af2c82e..0d0dad6 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.shuffle
 
 import java.io._
+import java.nio.ByteBuffer
 import java.nio.channels.Channels
 import java.nio.file.Files
 
@@ -25,8 +26,10 @@ import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.internal.Logging
 import org.apache.spark.io.NioBufferedFileInputStream
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.client.StreamCallbackWithID
 import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.network.shuffle.ExecutorDiskUtils
+import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
@@ -44,9 +47,10 @@ import org.apache.spark.util.Utils
 // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getSortBasedShuffleBlockData().
 private[spark] class IndexShuffleBlockResolver(
     conf: SparkConf,
-    _blockManager: BlockManager = null)
+    // var for testing
+    var _blockManager: BlockManager = null)
   extends ShuffleBlockResolver
-  with Logging {
+  with Logging with MigratableResolver {
 
   private lazy val blockManager = Option(_blockManager).getOrElse(SparkEnv.get.blockManager)
 
@@ -56,6 +60,19 @@ private[spark] class IndexShuffleBlockResolver(
   def getDataFile(shuffleId: Int, mapId: Long): File = getDataFile(shuffleId, mapId, None)
 
   /**
+   * Get the shuffle files that are stored locally. Used for block migrations.
+   */
+  override def getStoredShuffles(): Seq[ShuffleBlockInfo] = {
+    val allBlocks = blockManager.diskBlockManager.getAllBlocks()
+    allBlocks.flatMap {
+      case ShuffleIndexBlockId(shuffleId, mapId, _) =>
+        Some(ShuffleBlockInfo(shuffleId, mapId))
+      case _ =>
+        None
+    }
+  }
+
+  /**
    * Get the shuffle data file.
    *
    * When the dirs parameter is None then use the disk manager's local directories. Otherwise,
@@ -149,6 +166,82 @@ private[spark] class IndexShuffleBlockResolver(
   }
 
   /**
+   * Write a provided shuffle block as a stream. Used for block migrations.
+   * ShuffleBlockBatchIds must contain the full range represented in the ShuffleIndexBlock.
+   * Requires the caller to delete any shuffle index blocks where the shuffle block fails to
+   * put.
+   */
+  override def putShuffleBlockAsStream(blockId: BlockId, serializerManager: SerializerManager):
+      StreamCallbackWithID = {
+    val file = blockId match {
+      case ShuffleIndexBlockId(shuffleId, mapId, _) =>
+        getIndexFile(shuffleId, mapId)
+      case ShuffleDataBlockId(shuffleId, mapId, _) =>
+        getDataFile(shuffleId, mapId)
+      case _ =>
+        throw new IllegalStateException(s"Unexpected shuffle block transfer ${blockId} as " +
+          s"${blockId.getClass().getSimpleName()}")
+    }
+    val fileTmp = Utils.tempFileWith(file)
+    val channel = Channels.newChannel(
+      serializerManager.wrapStream(blockId,
+        new FileOutputStream(fileTmp)))
+
+    new StreamCallbackWithID {
+
+      override def getID: String = blockId.name
+
+      override def onData(streamId: String, buf: ByteBuffer): Unit = {
+        while (buf.hasRemaining) {
+          channel.write(buf)
+        }
+      }
+
+      override def onComplete(streamId: String): Unit = {
+        logTrace(s"Done receiving shuffle block $blockId, now storing on local disk.")
+        channel.close()
+        val diskSize = fileTmp.length()
+        // Synchronize on the resolver (a bare `this` here is the anonymous callback) so that the
+        // delete + rename is atomic with respect to writeIndexFileAndCommit, which also locks it.
+        IndexShuffleBlockResolver.this.synchronized {
+          if (file.exists()) file.delete()
+          if (!fileTmp.renameTo(file)) {
+            throw new IOException(s"fail to rename file ${fileTmp} to ${file}")
+          }
+        }
+        blockManager.reportBlockStatus(blockId, BlockStatus(StorageLevel.DISK_ONLY, 0, diskSize))
+      }
+
+      override def onFailure(streamId: String, cause: Throwable): Unit = {
+        // the framework handles the connection itself, we just need to do local cleanup
+        logWarning(s"Error while uploading $blockId", cause)
+        channel.close()
+        fileTmp.delete()
+      }
+    }
+  }
+
+  /**
+   * Get the index & data block for migration.
+   */
+  override def getMigrationBlocks(
+      shuffleBlockInfo: ShuffleBlockInfo): List[(BlockId, ManagedBuffer)] = {
+    val shuffleId = shuffleBlockInfo.shuffleId
+    val mapId = shuffleBlockInfo.mapId
+    // Load the index block
+    val indexFile = getIndexFile(shuffleId, mapId)
+    val indexBlockId = ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID)
+    val indexFileSize = indexFile.length()
+    val indexBlockData = new FileSegmentManagedBuffer(transportConf, indexFile, 0, indexFileSize)
+    // Load the data block
+    val dataFile = getDataFile(shuffleId, mapId)
+    val dataBlockId = ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)
+    val dataBlockData = new FileSegmentManagedBuffer(transportConf, dataFile, 0, dataFile.length())
+    List((indexBlockId, indexBlockData), (dataBlockId, dataBlockData))
+  }
+
+
+  /**
    * Write an index file with the offsets of each block, plus a final offset at the end for the
    * end of the output file. This will be used by getBlockData to figure out where each block
    * begins and ends.
@@ -169,7 +262,7 @@ private[spark] class IndexShuffleBlockResolver(
       val dataFile = getDataFile(shuffleId, mapId)
       // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure
       // the following check and rename are atomic.
-      synchronized {
+      this.synchronized {
         val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length)
         if (existingLengths != null) {
           // Another attempt for the same task has already written our map outputs successfully,
diff --git a/core/src/main/scala/org/apache/spark/shuffle/MigratableResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/MigratableResolver.scala
new file mode 100644
index 0000000..3851fa6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/MigratableResolver.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.client.StreamCallbackWithID
+import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.storage.BlockId
+
+/**
+ * :: Experimental ::
+ * An experimental trait to allow Spark to migrate shuffle blocks.
+ */
+@Experimental
+@Since("3.1.0")
+trait MigratableResolver {
+  /**
+   * Get the (shuffleId, mapId) pairs of the shuffle files stored locally. Used for migrations.
+   */
+  def getStoredShuffles(): Seq[ShuffleBlockInfo]
+
+  /**
+   * Write a provided shuffle block as a stream. Used for block migrations.
+   */
+  def putShuffleBlockAsStream(blockId: BlockId, serializerManager: SerializerManager):
+      StreamCallbackWithID
+
+  /**
+   * Get the blocks for migration for a particular shuffle and map.
+   */
+  def getMigrationBlocks(shuffleBlockInfo: ShuffleBlockInfo): List[(BlockId, ManagedBuffer)]
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockInfo.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockInfo.scala
new file mode 100644
index 0000000..99ceee8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockInfo.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import org.apache.spark.annotation.Experimental
+
+/**
+ * :: Experimental ::
+ * An experimental case class used by MigratableResolver to return the shuffleId and mapId in a
+ * type safe way.
+ */
+@Experimental
+case class ShuffleBlockInfo(shuffleId: Int, mapId: Long)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 68ed3aa..7b084e7 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -38,7 +38,10 @@ sealed abstract class BlockId {
   // convenience methods
   def asRDDId: Option[RDDBlockId] = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None
   def isRDD: Boolean = isInstanceOf[RDDBlockId]
-  def isShuffle: Boolean = isInstanceOf[ShuffleBlockId] || isInstanceOf[ShuffleBlockBatchId]
+  def isShuffle: Boolean = {
+    (isInstanceOf[ShuffleBlockId] || isInstanceOf[ShuffleBlockBatchId] ||
+     isInstanceOf[ShuffleDataBlockId] || isInstanceOf[ShuffleIndexBlockId])
+  }
   def isBroadcast: Boolean = isInstanceOf[BroadcastBlockId]
 
   override def toString: String = name
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 6eec288..47af854 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -24,6 +24,7 @@ import java.nio.channels.Channels
 import java.util.Collections
 import java.util.concurrent.{CompletableFuture, ConcurrentHashMap, TimeUnit}
 
+import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.collection.mutable.HashMap
 import scala.concurrent.{ExecutionContext, Future}
@@ -53,6 +54,7 @@ import org.apache.spark.network.util.TransportConf
 import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
 import org.apache.spark.serializer.{SerializerInstance, SerializerManager}
+import org.apache.spark.shuffle.{MigratableResolver, ShuffleManager, ShuffleWriteMetricsReporter}
 import org.apache.spark.shuffle.{ShuffleManager, ShuffleWriteMetricsReporter}
 import org.apache.spark.storage.BlockManagerMessages.ReplicateBlock
 import org.apache.spark.storage.memory._
@@ -242,8 +244,8 @@ private[spark] class BlockManager(
 
   private var blockReplicationPolicy: BlockReplicationPolicy = _
 
-  private var blockManagerDecommissioning: Boolean = false
-  private var decommissionManager: Option[BlockManagerDecommissionManager] = None
+  // This is volatile since if it's defined we should not accept remote blocks.
+  @volatile private var decommissioner: Option[BlockManagerDecommissioner] = None
 
   // A DownloadFileManager used to track all the files of remote blocks which are above the
   // specified memory threshold. Files will be deleted automatically based on weak reference.
@@ -254,6 +256,15 @@ private[spark] class BlockManager(
 
   var hostLocalDirManager: Option[HostLocalDirManager] = None
 
+  @inline final private def isDecommissioning() = {
+    decommissioner.isDefined
+  }
+  // Lazy so that RDD blocks can still be migrated without a MigratableResolver for shuffles;
+  // the cast only runs on first use. Used in BlockManagerDecommissioner & shuffle block puts.
+  private[storage] lazy val migratableResolver: MigratableResolver = {
+    shuffleManager.shuffleBlockResolver.asInstanceOf[MigratableResolver]
+  }
+
   /**
    * Abstraction for storing blocks from bytes, whether they start in memory or on disk.
    *
@@ -364,7 +375,7 @@ private[spark] class BlockManager(
             ThreadUtils.awaitReady(replicationFuture, Duration.Inf)
           } catch {
             case NonFatal(t) =>
-              throw new Exception("Error occurred while waiting for replication to finish", t)
+              throw new SparkException("Error occurred while waiting for replication to finish", t)
           }
         }
         if (blockWasSuccessfullyStored) {
@@ -617,6 +628,7 @@ private[spark] class BlockManager(
    */
   override def getLocalBlockData(blockId: BlockId): ManagedBuffer = {
     if (blockId.isShuffle) {
+      logInfo(s"Getting local shuffle block ${blockId}")
       shuffleManager.shuffleBlockResolver.getBlockData(blockId)
     } else {
       getLocalBytes(blockId) match {
@@ -650,6 +662,23 @@ private[spark] class BlockManager(
       blockId: BlockId,
       level: StorageLevel,
       classTag: ClassTag[_]): StreamCallbackWithID = {
+
+    if (isDecommissioning()) {
+       throw new BlockSavedOnDecommissionedBlockManagerException(blockId)
+    }
+
+    if (blockId.isShuffle) {
+      logDebug(s"Putting shuffle block ${blockId}")
+      try {
+        return migratableResolver.putShuffleBlockAsStream(blockId, serializerManager)
+      } catch {
+        case e: ClassCastException => throw new SparkException(
+          s"Unexpected shuffle block ${blockId} with unsupported shuffle " +
+          s"resolver ${shuffleManager.shuffleBlockResolver}")
+      }
+    }
+    logDebug(s"Putting regular block ${blockId}")
+    // All other blocks
     val (_, tmpFile) = diskBlockManager.createTempLocalBlock()
     val channel = new CountingWritableChannel(
       Channels.newChannel(serializerManager.wrapForEncryption(new FileOutputStream(tmpFile))))
@@ -720,7 +749,7 @@ private[spark] class BlockManager(
    * it is still valid). This ensures that update in master will compensate for the increase in
    * memory on the storage endpoint.
    */
-  private def reportBlockStatus(
+  private[spark] def reportBlockStatus(
       blockId: BlockId,
       status: BlockStatus,
       droppedMemorySize: Long = 0L): Unit = {
@@ -1285,6 +1314,9 @@ private[spark] class BlockManager(
 
     require(blockId != null, "BlockId is null")
     require(level != null && level.isValid, "StorageLevel is null or invalid")
+    if (isDecommissioning()) {
+      throw new BlockSavedOnDecommissionedBlockManagerException(blockId)
+    }
 
     val putBlockInfo = {
       val newInfo = new BlockInfo(level, classTag, tellMaster)
@@ -1540,7 +1572,7 @@ private[spark] class BlockManager(
   /**
    * Get peer block managers in the system.
    */
-  private def getPeers(forceFetch: Boolean): Seq[BlockManagerId] = {
+  private[storage] def getPeers(forceFetch: Boolean): Seq[BlockManagerId] = {
     peerFetchLock.synchronized {
       val cachedPeersTtl = conf.get(config.STORAGE_CACHED_PEERS_TTL) // milliseconds
       val diff = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - lastPeerFetchTimeNs)
@@ -1779,59 +1811,19 @@ private[spark] class BlockManager(
     blocksToRemove.size
   }
 
-  def decommissionBlockManager(): Unit = {
-    if (!blockManagerDecommissioning) {
-      logInfo("Starting block manager decommissioning process")
-      blockManagerDecommissioning = true
-      decommissionManager = Some(new BlockManagerDecommissionManager(conf))
-      decommissionManager.foreach(_.start())
-    } else {
-      logDebug("Block manager already in decommissioning state")
+  def decommissionBlockManager(): Unit = synchronized {
+    if (decommissioner.isEmpty) {
+      logInfo("Starting block manager decommissioning process...")
+      decommissioner = Some(new BlockManagerDecommissioner(conf, this))
+      decommissioner.foreach(_.start())
+    } else {
+      logDebug("Block manager already in decommissioning state")
     }
   }
 
-  /**
-   * Tries to offload all cached RDD blocks from this BlockManager to peer BlockManagers
-   * Visible for testing
-   */
-  def decommissionRddCacheBlocks(): Unit = {
-    val replicateBlocksInfo = master.getReplicateInfoForRDDBlocks(blockManagerId)
-
-    if (replicateBlocksInfo.nonEmpty) {
-      logInfo(s"Need to replicate ${replicateBlocksInfo.size} blocks " +
-        "for block manager decommissioning")
-    } else {
-      logWarning(s"Asked to decommission RDD cache blocks, but no blocks to migrate")
-      return
-    }
-
-    // Maximum number of storage replication failure which replicateBlock can handle
-    val maxReplicationFailures = conf.get(
-      config.STORAGE_DECOMMISSION_MAX_REPLICATION_FAILURE_PER_BLOCK)
-
-    // TODO: We can sort these blocks based on some policy (LRU/blockSize etc)
-    //   so that we end up prioritize them over each other
-    val blocksFailedReplication = replicateBlocksInfo.map {
-      case ReplicateBlock(blockId, existingReplicas, maxReplicas) =>
-        val replicatedSuccessfully = replicateBlock(
-          blockId,
-          existingReplicas.toSet,
-          maxReplicas,
-          maxReplicationFailures = Some(maxReplicationFailures))
-        if (replicatedSuccessfully) {
-          logInfo(s"Block $blockId offloaded successfully, Removing block now")
-          removeBlock(blockId)
-          logInfo(s"Block $blockId removed")
-        } else {
-          logWarning(s"Failed to offload block $blockId")
-        }
-        (blockId, replicatedSuccessfully)
-    }.filterNot(_._2).map(_._1)
-    if (blocksFailedReplication.nonEmpty) {
-      logWarning("Blocks failed replication in cache decommissioning " +
-        s"process: ${blocksFailedReplication.mkString(",")}")
-    }
-  }
+  /** Asks the master for replication info on this block manager's cached RDD blocks. */
+  private[storage] def getMigratableRDDBlocks(): Seq[ReplicateBlock] =
+    master.getReplicateInfoForRDDBlocks(blockManagerId)
 
   /**
    * Remove all blocks belonging to the given broadcast.
@@ -1901,58 +1893,8 @@ private[spark] class BlockManager(
     data.dispose()
   }
 
-  /**
-   * Class to handle block manager decommissioning retries
-   * It creates a Thread to retry offloading all RDD cache blocks
-   */
-  private class BlockManagerDecommissionManager(conf: SparkConf) {
-    @volatile private var stopped = false
-    private val sleepInterval = conf.get(
-      config.STORAGE_DECOMMISSION_REPLICATION_REATTEMPT_INTERVAL)
-
-    private val blockReplicationThread = new Thread {
-      override def run(): Unit = {
-        var failures = 0
-        while (blockManagerDecommissioning
-          && !stopped
-          && !Thread.interrupted()
-          && failures < 20) {
-          try {
-            logDebug("Attempting to replicate all cached RDD blocks")
-            decommissionRddCacheBlocks()
-            logInfo("Attempt to replicate all cached blocks done")
-            Thread.sleep(sleepInterval)
-          } catch {
-            case _: InterruptedException =>
-              logInfo("Interrupted during migration, will not refresh migrations.")
-              stopped = true
-            case NonFatal(e) =>
-              failures += 1
-              logError("Error occurred while trying to replicate cached RDD blocks" +
-                s" for block manager decommissioning (failure count: $failures)", e)
-          }
-        }
-      }
-    }
-    blockReplicationThread.setDaemon(true)
-    blockReplicationThread.setName("block-replication-thread")
-
-    def start(): Unit = {
-      logInfo("Starting block replication thread")
-      blockReplicationThread.start()
-    }
-
-    def stop(): Unit = {
-      if (!stopped) {
-        stopped = true
-        logInfo("Stopping block replication thread")
-        blockReplicationThread.interrupt()
-      }
-    }
-  }
-
   def stop(): Unit = {
-    decommissionManager.foreach(_.stop())
+    decommissioner.foreach(_.stop())
     blockTransferService.close()
     if (blockStoreClient ne blockTransferService) {
       // Closing should be idempotent, but maybe not for the NioBlockTransferService.
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerDecommissioner.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerDecommissioner.scala
new file mode 100644
index 0000000..1cc7ef6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerDecommissioner.scala
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.ExecutorService
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.util.control.NonFatal
+
+import org.apache.spark._
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config
+import org.apache.spark.shuffle.{MigratableResolver, ShuffleBlockInfo}
+import org.apache.spark.storage.BlockManagerMessages.ReplicateBlock
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * Class to handle block manager decommissioning retries.
+ * It creates threads to retry offloading all RDD cache and shuffle blocks.
+ *
+ * The RDD and shuffle migration loops stop independently when they fail, and stop() still
+ * shuts down every executor used for migration even after such a failure.
+ */
+private[storage] class BlockManagerDecommissioner(
+    conf: SparkConf,
+    bm: BlockManager) extends Logging {
+
+  private val maxReplicationFailuresForDecommission =
+    conf.get(config.STORAGE_DECOMMISSION_MAX_REPLICATION_FAILURE_PER_BLOCK)
+
+  /**
+   * This runnable consumes any shuffle blocks in the queue for migration. This is part of a
+   * producer/consumer where the main migration loop updates the queue of blocks to be migrated
+   * periodically. On migration failure, the current thread will reinsert the block for another
+   * thread to consume. Each thread migrates blocks to a different particular executor so that
+   * blocks are distributed as quickly as possible without overwhelming any particular executor.
+   *
+   * There is no preference for which peer a given block is migrated to.
+   * This is notably different than the RDD cache block migration (further down in this file)
+   * which uses the existing priority mechanism for determining where to replicate blocks to.
+   * Generally speaking cache blocks are less impactful as they normally represent narrow
+   * transformations and we normally have less cache present than shuffle data.
+   *
+   * The producer/consumer model is chosen for shuffle block migration to maximize
+   * the chance of migrating all shuffle blocks before the executor is forced to exit.
+   */
+  private class ShuffleMigrationRunnable(peer: BlockManagerId) extends Runnable {
+    @volatile var running = true
+    override def run(): Unit = {
+      var migrating: Option[(ShuffleBlockInfo, Int)] = None
+      logInfo(s"Starting migration thread for ${peer}")
+      // Once a block fails to transfer to an executor stop trying to transfer more blocks
+      try {
+        while (running && !Thread.interrupted()) {
+          migrating = Option(shufflesToMigrate.poll())
+          migrating match {
+            case None =>
+              logDebug("Nothing to migrate")
+              // Nothing to do right now, but maybe a transfer will fail or a new block
+              // will finish being committed.
+              val SLEEP_TIME_SECS = 1
+              Thread.sleep(SLEEP_TIME_SECS * 1000L)
+            case Some((shuffleBlockInfo, retryCount)) =>
+              if (retryCount < maxReplicationFailuresForDecommission) {
+                logInfo(s"Trying to migrate shuffle ${shuffleBlockInfo} to ${peer}")
+                val blocks =
+                  bm.migratableResolver.getMigrationBlocks(shuffleBlockInfo)
+                logDebug(s"Got migration sub-blocks ${blocks}")
+                blocks.foreach { case (blockId, buffer) =>
+                  logDebug(s"Migrating sub-block ${blockId}")
+                  bm.blockTransferService.uploadBlockSync(
+                    peer.host,
+                    peer.port,
+                    peer.executorId,
+                    blockId,
+                    buffer,
+                    StorageLevel.DISK_ONLY,
+                    null) // class tag, we don't need for shuffle
+                  logDebug(s"Migrated sub block ${blockId}")
+                }
+                logInfo(s"Migrated ${shuffleBlockInfo} to ${peer}")
+              } else {
+                logError(s"Skipping block ${shuffleBlockInfo} because it has failed " +
+                  s"${retryCount} times")
+              }
+          }
+        }
+        // This catch is intentionally outside of the while running block.
+        // if we encounter errors migrating to an executor we want to stop.
+      } catch {
+        case e: Exception =>
+          migrating match {
+            case Some((shuffleMap, retryCount)) =>
+              logError(s"Error during migration, adding ${shuffleMap} back to migration queue", e)
+              shufflesToMigrate.add((shuffleMap, retryCount + 1))
+            case None =>
+              logError(s"Error while waiting for block to migrate", e)
+          }
+      }
+    }
+  }
+
+  // Shuffles which are either in queue for migrations or migrated
+  private val migratingShuffles = mutable.HashSet[ShuffleBlockInfo]()
+
+  // Shuffles which are queued for migration & number of retries so far.
+  private[storage] val shufflesToMigrate =
+    new java.util.concurrent.ConcurrentLinkedQueue[(ShuffleBlockInfo, Int)]()
+
+  // Set by stop() (or by start() when no migration is enabled); makes stop() idempotent.
+  @volatile private var stopped = false
+
+  // Set when the RDD migration loop gives up. Kept separate from `stopped` so that an RDD
+  // migration failure neither halts shuffle migration nor turns stop() into a no-op.
+  @volatile private var stoppedRDD = false
+
+  // Set when the shuffle refresh loop gives up; see `stoppedRDD`.
+  @volatile private var stoppedShuffle = false
+
+  // Mutated by the shuffle refresh thread and read by stop(); guard with its own monitor.
+  private val migrationPeers =
+    mutable.HashMap[BlockManagerId, ShuffleMigrationRunnable]()
+
+  private lazy val rddBlockMigrationExecutor =
+    ThreadUtils.newDaemonSingleThreadExecutor("block-manager-decommission-rdd")
+
+  private val rddBlockMigrationRunnable = new Runnable {
+    val sleepInterval = conf.get(config.STORAGE_DECOMMISSION_REPLICATION_REATTEMPT_INTERVAL)
+
+    override def run(): Unit = {
+      assert(conf.get(config.STORAGE_DECOMMISSION_RDD_BLOCKS_ENABLED))
+      while (!stopped && !stoppedRDD && !Thread.interrupted()) {
+        logInfo("Iterating on migrating from the block manager.")
+        try {
+          logDebug("Attempting to replicate all cached RDD blocks")
+          decommissionRddCacheBlocks()
+          logInfo("Attempt to replicate all cached blocks done")
+          logInfo(s"Waiting for ${sleepInterval} before refreshing migrations.")
+          Thread.sleep(sleepInterval)
+        } catch {
+          case e: InterruptedException =>
+            logInfo("Interrupted during migration, will not refresh migrations.")
+            stoppedRDD = true
+          case NonFatal(e) =>
+            logError("Error occurred while trying to replicate for block manager decommissioning.",
+              e)
+            stoppedRDD = true
+        }
+      }
+    }
+  }
+
+  private lazy val shuffleBlockMigrationRefreshExecutor =
+    ThreadUtils.newDaemonSingleThreadExecutor("block-manager-decommission-shuffle")
+
+  private val shuffleBlockMigrationRefreshRunnable = new Runnable {
+    val sleepInterval = conf.get(config.STORAGE_DECOMMISSION_REPLICATION_REATTEMPT_INTERVAL)
+
+    override def run(): Unit = {
+      assert(conf.get(config.STORAGE_DECOMMISSION_SHUFFLE_BLOCKS_ENABLED))
+      while (!stopped && !stoppedShuffle && !Thread.interrupted()) {
+        try {
+          logDebug("Attempting to replicate all shuffle blocks")
+          refreshOffloadingShuffleBlocks()
+          logInfo("Done starting workers to migrate shuffle blocks")
+          Thread.sleep(sleepInterval)
+        } catch {
+          case e: InterruptedException =>
+            logInfo("Interrupted during migration, will not refresh migrations.")
+            stoppedShuffle = true
+          case NonFatal(e) =>
+            logError("Error occurred while trying to replicate for block manager decommissioning.",
+              e)
+            stoppedShuffle = true
+        }
+      }
+    }
+  }
+
+  lazy val shuffleMigrationPool = ThreadUtils.newDaemonCachedThreadPool(
+    "migrate-shuffles",
+    conf.get(config.STORAGE_DECOMMISSION_SHUFFLE_MAX_THREADS))
+
+  /**
+   * Tries to offload all shuffle blocks that are registered with the shuffle service locally.
+   * Note: this does not delete the shuffle files in-case there is an in-progress fetch
+   * but rather shadows them.
+   * Requires an Indexed based shuffle resolver.
+   * Note: if called in testing please call stopOffloadingShuffleBlocks to avoid thread leakage.
+   */
+  private[storage] def refreshOffloadingShuffleBlocks(): Unit = {
+    // Update the queue of shuffles to be migrated
+    logInfo("Offloading shuffle blocks")
+    val localShuffles = bm.migratableResolver.getStoredShuffles().toSet
+    val newShufflesToMigrate = localShuffles.diff(migratingShuffles).toSeq
+    shufflesToMigrate.addAll(newShufflesToMigrate.map(x => (x, 0)).asJava)
+    migratingShuffles ++= newShufflesToMigrate
+
+    // Update the threads doing migrations
+    val livePeerSet = bm.getPeers(false).toSet
+    migrationPeers.synchronized {
+      val currentPeerSet = migrationPeers.keys.toSet
+      val deadPeers = currentPeerSet.diff(livePeerSet)
+      val newPeers = livePeerSet.diff(currentPeerSet)
+      migrationPeers ++= newPeers.map { peer =>
+        logDebug(s"Starting thread to migrate shuffle blocks to ${peer}")
+        val runnable = new ShuffleMigrationRunnable(peer)
+        shuffleMigrationPool.submit(runnable)
+        (peer, runnable)
+      }
+      // A peer may have entered a decommissioning state, don't transfer any new blocks
+      deadPeers.foreach { peer =>
+        migrationPeers.get(peer).foreach(_.running = false)
+      }
+    }
+  }
+
+  /**
+   * Stop migrating shuffle blocks.
+   */
+  private[storage] def stopOffloadingShuffleBlocks(): Unit = {
+    logInfo("Stopping offloading shuffle blocks.")
+    // Ask every migration thread to stop taking new blocks before interrupting them.
+    migrationPeers.synchronized {
+      migrationPeers.values.foreach(_.running = false)
+    }
+    // shutdownNow() implies shutdown() and additionally interrupts in-flight migrations.
+    shuffleMigrationPool.shutdownNow()
+  }
+
+  /**
+   * Tries to offload all cached RDD blocks from this BlockManager to peer BlockManagers
+   * Visible for testing
+   */
+  private[storage] def decommissionRddCacheBlocks(): Unit = {
+    val replicateBlocksInfo = bm.getMigratableRDDBlocks()
+
+    if (replicateBlocksInfo.isEmpty) {
+      logWarning("Asked to decommission RDD cache blocks, but no blocks to migrate")
+    } else {
+      logInfo(s"Need to replicate ${replicateBlocksInfo.size} RDD blocks " +
+        "for block manager decommissioning")
+
+      // TODO: We can sort these blocks based on some policy (LRU/blockSize etc)
+      //   so that we end up prioritize them over each other
+      // Every block is attempted in order; only the ids of the failed ones are kept.
+      val blocksFailedReplication =
+        replicateBlocksInfo.filterNot(block => migrateBlock(block)).map(_.blockId)
+      if (blocksFailedReplication.nonEmpty) {
+        logWarning("Blocks failed replication in cache decommissioning " +
+          s"process: ${blocksFailedReplication.mkString(",")}")
+      }
+    }
+  }
+
+  private def migrateBlock(blockToReplicate: ReplicateBlock): Boolean = {
+    val replicatedSuccessfully = bm.replicateBlock(
+      blockToReplicate.blockId,
+      blockToReplicate.replicas.toSet,
+      blockToReplicate.maxReplicas,
+      maxReplicationFailures = Some(maxReplicationFailuresForDecommission))
+    if (replicatedSuccessfully) {
+      logInfo(s"Block ${blockToReplicate.blockId} offloaded successfully, Removing block now")
+      bm.removeBlock(blockToReplicate.blockId)
+      logInfo(s"Block ${blockToReplicate.blockId} removed")
+    } else {
+      logWarning(s"Failed to offload block ${blockToReplicate.blockId}")
+    }
+    replicatedSuccessfully
+  }
+
+  def start(): Unit = {
+    logInfo("Starting block migration thread")
+    if (conf.get(config.STORAGE_DECOMMISSION_RDD_BLOCKS_ENABLED)) {
+      rddBlockMigrationExecutor.submit(rddBlockMigrationRunnable)
+    }
+    if (conf.get(config.STORAGE_DECOMMISSION_SHUFFLE_BLOCKS_ENABLED)) {
+      shuffleBlockMigrationRefreshExecutor.submit(shuffleBlockMigrationRefreshRunnable)
+    }
+    if (!conf.get(config.STORAGE_DECOMMISSION_SHUFFLE_BLOCKS_ENABLED) &&
+      !conf.get(config.STORAGE_DECOMMISSION_RDD_BLOCKS_ENABLED)) {
+      logError(s"Storage decommissioning attempted but neither " +
+        s"${config.STORAGE_DECOMMISSION_SHUFFLE_BLOCKS_ENABLED.key} or " +
+        s"${config.STORAGE_DECOMMISSION_RDD_BLOCKS_ENABLED.key} is enabled ")
+      stopped = true
+    }
+  }
+
+  def stop(): Unit = {
+    if (stopped) {
+      return
+    } else {
+      stopped = true
+    }
+    try {
+      rddBlockMigrationExecutor.shutdown()
+    } catch {
+      case e: Exception =>
+        logError(s"Error during shutdown", e)
+    }
+    try {
+      shuffleBlockMigrationRefreshExecutor.shutdown()
+    } catch {
+      case e: Exception =>
+        logError(s"Error during shutdown", e)
+    }
+    try {
+      stopOffloadingShuffleBlocks()
+    } catch {
+      case e: Exception =>
+        logError(s"Error during shutdown", e)
+    }
+    logInfo("Forcing block migrations threads to stop")
+    try {
+      rddBlockMigrationExecutor.shutdownNow()
+    } catch {
+      case e: Exception =>
+        logError(s"Error during shutdown", e)
+    }
+    try {
+      shuffleBlockMigrationRefreshExecutor.shutdownNow()
+    } catch {
+      case e: Exception =>
+        logError(s"Error during shutdown", e)
+    }
+    logInfo("Stopped storage decommissioner")
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index 2a48177..a3d4234 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -29,7 +29,7 @@ import scala.util.control.NonFatal
 
 import com.google.common.cache.CacheBuilder
 
-import org.apache.spark.SparkConf
+import org.apache.spark.{MapOutputTrackerMaster, SparkConf}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.network.shuffle.ExternalBlockStoreClient
@@ -50,7 +50,8 @@ class BlockManagerMasterEndpoint(
     conf: SparkConf,
     listenerBus: LiveListenerBus,
     externalBlockStoreClient: Option[ExternalBlockStoreClient],
-    blockManagerInfo: mutable.Map[BlockManagerId, BlockManagerInfo])
+    blockManagerInfo: mutable.Map[BlockManagerId, BlockManagerInfo],
+    mapOutputTracker: MapOutputTrackerMaster)
   extends IsolatedRpcEndpoint with Logging {
 
   // Mapping from executor id to the block manager's local disk directories.
@@ -162,7 +163,8 @@ class BlockManagerMasterEndpoint(
       context.reply(true)
 
     case DecommissionBlockManagers(executorIds) =>
-      decommissionBlockManagers(executorIds.flatMap(blockManagerIdByExecutor.get))
+      val bmIds = executorIds.flatMap(blockManagerIdByExecutor.get)
+      decommissionBlockManagers(bmIds)
       context.reply(true)
 
     case GetReplicateInfoForRDDBlocks(blockManagerId) =>
@@ -539,6 +541,24 @@ class BlockManagerMasterEndpoint(
       storageLevel: StorageLevel,
       memSize: Long,
       diskSize: Long): Boolean = {
+    logDebug(s"Updating block info on master ${blockId} for ${blockManagerId}")
+
+    if (blockId.isShuffle) {
+      blockId match {
+        case ShuffleIndexBlockId(shuffleId, mapId, _) =>
+          // Don't update the map output on just the index block
+          logDebug(s"Received shuffle index block update for ${shuffleId} ${mapId}, ignoring.")
+          return true
+        case ShuffleDataBlockId(shuffleId: Int, mapId: Long, reduceId: Int) =>
+          logDebug(s"Received shuffle data block update for ${shuffleId} ${mapId}, updating.")
+          mapOutputTracker.updateMapOutput(shuffleId, mapId, blockManagerId)
+          return true
+        case _ =>
+          logDebug(s"Unexpected shuffle block type ${blockId}" +
+            s"as ${blockId.getClass().getSimpleName()}")
+          return false
+      }
+    }
 
     if (!blockManagerInfo.contains(blockManagerId)) {
       if (blockManagerId.isDriver && !isLocal) {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockSavedOnDecommissionedBlockManagerException.scala b/core/src/main/scala/org/apache/spark/storage/BlockSavedOnDecommissionedBlockManagerException.scala
new file mode 100644
index 0000000..4684d9c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockSavedOnDecommissionedBlockManagerException.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+/**
+ * Raised when a block would be stored on a block manager that is already decommissioning.
+ *
+ * @param blockId the block whose storage was refused
+ */
+class BlockSavedOnDecommissionedBlockManagerException(blockId: BlockId)
+  extends Exception(s"Block ${blockId} cannot be saved on decommissioned executor")
diff --git a/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionSuite.scala
index 148d20e..cd3ab4d 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionSuite.scala
@@ -58,7 +58,7 @@ class WorkerDecommissionSuite extends SparkFunSuite with LocalSparkContext {
     })
     TestUtils.waitUntilExecutorsUp(sc = sc,
       numExecutors = 2,
-      timeout = 10000) // 10s
+      timeout = 30000) // 30s
     val sleepyRdd = input.mapPartitions{ x =>
       Thread.sleep(5000) // 5s
       x
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
index 27bb06b..725a1d9 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
@@ -27,7 +27,7 @@ import org.mockito.invocation.InvocationOnMock
 import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.shuffle.IndexShuffleBlockResolver
+import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleBlockInfo}
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
 
@@ -48,6 +48,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
     when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
     when(diskBlockManager.getFile(any[BlockId])).thenAnswer(
       (invocation: InvocationOnMock) => new File(tempDir, invocation.getArguments.head.toString))
+    when(diskBlockManager.localDirs).thenReturn(Array(tempDir))
   }
 
   override def afterEach(): Unit = {
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
index ef7b138..d7009e6 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
@@ -87,7 +87,7 @@ class BlockIdSuite extends SparkFunSuite {
     assert(id.shuffleId === 4)
     assert(id.mapId === 5)
     assert(id.reduceId === 6)
-    assert(!id.isShuffle)
+    assert(id.isShuffle)
     assertSame(id, BlockId(id.toString))
   }
 
@@ -100,7 +100,7 @@ class BlockIdSuite extends SparkFunSuite {
     assert(id.shuffleId === 7)
     assert(id.mapId === 8)
     assert(id.reduceId === 9)
-    assert(!id.isShuffle)
+    assert(id.isShuffle)
     assertSame(id, BlockId(id.toString))
   }
 
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala
new file mode 100644
index 0000000..afcb38b
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.{ConcurrentLinkedQueue, Semaphore}
+
+import scala.collection.JavaConverters._
+import scala.concurrent.duration._
+
+import org.scalatest.concurrent.Eventually
+
+import org.apache.spark._
+import org.apache.spark.internal.config
+import org.apache.spark.scheduler._
+import org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend
+import org.apache.spark.util.{ResetSystemProperties, ThreadUtils}
+
+class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalSparkContext
+    with ResetSystemProperties with Eventually {
+
+  val numExecs = 3
+  val numParts = 3
+
+  test("verify that an already running task which is going to cache data succeeds " +
+    "on a decommissioned executor") {
+    runDecomTest(persist = true, shuffle = false, migrateDuring = true)
+  }
+
+  test("verify that shuffle blocks are migrated") {
+    runDecomTest(persist = false, shuffle = true, migrateDuring = false)
+  }
+
+  test("verify that both migrations can work at the same time.") {
+    runDecomTest(persist = true, shuffle = true, migrateDuring = false)
+  }
+
+  /**
+   * Runs a job on a local cluster, decommissions one executor while (or after) it runs, and
+   * checks that the cached and/or shuffled data is reused without recomputing the mappers.
+   *
+   * @param persist cache the test RDD and enable RDD block migration
+   * @param shuffle add a shuffle stage and enable shuffle block migration
+   * @param migrateDuring decommission while tasks are still running instead of after the job
+   */
+  private def runDecomTest(persist: Boolean, shuffle: Boolean, migrateDuring: Boolean): Unit = {
+    val master = s"local-cluster[${numExecs}, 1, 1024]"
+    val conf = new SparkConf().setAppName("test").setMaster(master)
+      .set(config.Worker.WORKER_DECOMMISSION_ENABLED, true)
+      .set(config.STORAGE_DECOMMISSION_ENABLED, true)
+      .set(config.STORAGE_DECOMMISSION_RDD_BLOCKS_ENABLED, persist)
+      .set(config.STORAGE_DECOMMISSION_SHUFFLE_BLOCKS_ENABLED, shuffle)
+      // Just replicate blocks as fast as we can during testing, there isn't another
+      // workload we need to worry about.
+      .set(config.STORAGE_DECOMMISSION_REPLICATION_REATTEMPT_INTERVAL, 1L)
+
+    sc = new SparkContext(master, "test", conf)
+
+    // Wait for the executors to start
+    TestUtils.waitUntilExecutorsUp(sc = sc,
+      numExecutors = numExecs,
+      timeout = 60000) // 60s
+
+    val input = sc.parallelize(1 to numParts, numParts)
+    val accum = sc.longAccumulator("mapperRunAccumulator")
+    input.count()
+
+    // Create a new RDD where we have sleep in each partition, we are also increasing
+    // the value of accumulator in each partition
+    val baseRdd = input.mapPartitions { x =>
+      if (migrateDuring) {
+        Thread.sleep(1000)
+      }
+      accum.add(1)
+      x.map(y => (y, y))
+    }
+    val testRdd = if (shuffle) baseRdd.reduceByKey(_ + _) else baseRdd
+
+    // Listen for the job & block updates. The callbacks run on the listener bus thread while
+    // this thread reads the collected events (also from inside `eventually`), so use
+    // thread-safe queues rather than plain ArrayBuffers.
+    val taskStartSem = new Semaphore(0)
+    val broadcastSem = new Semaphore(0)
+    val executorRemovedSem = new Semaphore(0)
+    val taskEndEvents = new ConcurrentLinkedQueue[SparkListenerTaskEnd]()
+    val blocksUpdated = new ConcurrentLinkedQueue[SparkListenerBlockUpdated]()
+    sc.addSparkListener(new SparkListener {
+
+      override def onExecutorRemoved(execRemoved: SparkListenerExecutorRemoved): Unit = {
+        executorRemovedSem.release()
+      }
+
+      override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+        taskStartSem.release()
+      }
+
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+        taskEndEvents.add(taskEnd)
+      }
+
+      override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = {
+        // Once broadcast start landing on the executors we're good to proceed.
+        // We don't only use task start as it can occur before the work is on the executor.
+        if (blockUpdated.blockUpdatedInfo.blockId.isBroadcast) {
+          broadcastSem.release()
+        }
+        blocksUpdated.add(blockUpdated)
+      }
+    })
+
+    // Cache the RDD lazily
+    if (persist) {
+      testRdd.persist()
+    }
+
+    // Start the computation of RDD - this step will also cache the RDD
+    val asyncCount = testRdd.countAsync()
+
+    // Wait for the job to have started.
+    taskStartSem.acquire(1)
+    // Wait for each executor + driver to have it's broadcast info delivered.
+    broadcastSem.acquire(numExecs + 1)
+
+    // Make sure the job is either mid run or otherwise has data to migrate.
+    if (migrateDuring) {
+      // Give Spark a tiny bit to start executing after the broadcast blocks land.
+      // For me this works at 100, set to 300 for system variance.
+      Thread.sleep(300)
+    } else {
+      ThreadUtils.awaitResult(asyncCount, 15.seconds)
+    }
+
+    // Decommission one of the executors.
+    val sched = sc.schedulerBackend.asInstanceOf[StandaloneSchedulerBackend]
+    val execs = sched.getExecutorIds()
+    assert(execs.size == numExecs, s"Expected ${numExecs} executors but found ${execs.size}")
+
+    val execToDecommission = execs.head
+    logDebug(s"Decommissioning executor ${execToDecommission}")
+    sched.decommissionExecutor(execToDecommission)
+
+    // Wait for job to finish.
+    val asyncCountResult = ThreadUtils.awaitResult(asyncCount, 15.seconds)
+    assert(asyncCountResult === numParts)
+    // All tasks finished, so accum should have been increased numParts times.
+    assert(accum.value === numParts)
+
+    sc.listenerBus.waitUntilEmpty()
+    // With a shuffle both mappers & reducers succeed; otherwise only the mappers do.
+    val expectedSuccessfulTasks = if (shuffle) 2 * numParts else numParts
+    val numSuccessfulTasks = taskEndEvents.asScala.count(_.reason == Success)
+    assert(numSuccessfulTasks === expectedSuccessfulTasks,
+      s"Expected ${expectedSuccessfulTasks} successful tasks but got ${numSuccessfulTasks} " +
+      s"(${taskEndEvents.asScala.toList})")
+
+    // Wait for our respective blocks to have migrated
+    eventually(timeout(30.seconds), interval(10.milliseconds)) {
+      // Take one snapshot so every assertion below sees the same set of updates.
+      val updates = blocksUpdated.asScala.toList
+      if (persist) {
+        // One of our blocks should have moved.
+        val rddUpdates = updates.filter(_.blockUpdatedInfo.blockId.isRDD)
+        val blockLocs = rddUpdates.map { update =>
+          (update.blockUpdatedInfo.blockId.name, update.blockUpdatedInfo.blockManagerId)
+        }
+        val blocksToManagers = blockLocs.groupBy(_._1).map { case (name, locs) =>
+          (name, locs.size)
+        }
+        assert(blocksToManagers.exists(_._2 > 1),
+          s"We should have a block that has been on multiple BMs in rdds:\n ${rddUpdates} from:\n" +
+          s"${updates}\n but instead we got:\n ${blocksToManagers}")
+      }
+      // If we're migrating shuffles we look for any shuffle block updates
+      // as there is no block update on the initial shuffle block write.
+      if (shuffle) {
+        val numDataLocs =
+          updates.count(_.blockUpdatedInfo.blockId.isInstanceOf[ShuffleDataBlockId])
+        val numIndexLocs =
+          updates.count(_.blockUpdatedInfo.blockId.isInstanceOf[ShuffleIndexBlockId])
+        assert(numDataLocs === 1, s"Expect shuffle data block updates in ${updates}")
+        assert(numIndexLocs === 1, s"Expect shuffle index block updates in ${updates}")
+      }
+    }
+
+    // Since the RDD is cached or shuffled so further usage of same RDD should use the
+    // cached data. Original RDD partitions should not be recomputed i.e. accum
+    // should have same value like before
+    assert(testRdd.count() === numParts)
+    assert(accum.value === numParts)
+
+    val storageStatus = sc.env.blockManager.master.getStorageStatus
+    val execIdToBlocksMapping = storageStatus.map(
+      status => (status.blockManagerId.executorId, status.blocks)).toMap
+    // No cached blocks should be present on executor which was decommissioned
+    assert(execIdToBlocksMapping(execToDecommission).keys.filter(_.isRDD).toSeq === Seq(),
+      "Cache blocks should be migrated")
+    if (persist) {
+      // There should still be all the RDD blocks cached
+      assert(execIdToBlocksMapping.values.flatMap(_.keys).count(_.isRDD) === numParts)
+    }
+
+    // Make the executor we decommissioned exit
+    sched.client.killExecutors(List(execToDecommission))
+
+    // Wait for the executor to be removed
+    executorRemovedSem.acquire(1)
+
+    // Since the RDD is cached or shuffled so further usage of same RDD should use the
+    // cached data. Original RDD partitions should not be recomputed i.e. accum
+    // should have same value like before
+    assert(testRdd.count() === numParts)
+    assert(accum.value === numParts)
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionSuite.scala
deleted file mode 100644
index 7456ca7..0000000
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionSuite.scala
+++ /dev/null
@@ -1,106 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.util.concurrent.Semaphore
-
-import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.duration._
-
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite, Success}
-import org.apache.spark.internal.config
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart}
-import org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend
-import org.apache.spark.util.{ResetSystemProperties, ThreadUtils}
-
-class BlockManagerDecommissionSuite extends SparkFunSuite with LocalSparkContext
-    with ResetSystemProperties {
-
-  override def beforeEach(): Unit = {
-    val conf = new SparkConf().setAppName("test")
-      .set(config.Worker.WORKER_DECOMMISSION_ENABLED, true)
-      .set(config.STORAGE_DECOMMISSION_ENABLED, true)
-
-    sc = new SparkContext("local-cluster[2, 1, 1024]", "test", conf)
-  }
-
-  test(s"verify that an already running task which is going to cache data succeeds " +
-    s"on a decommissioned executor") {
-    // Create input RDD with 10 partitions
-    val input = sc.parallelize(1 to 10, 10)
-    val accum = sc.longAccumulator("mapperRunAccumulator")
-    // Do a count to wait for the executors to be registered.
-    input.count()
-
-    // Create a new RDD where we have sleep in each partition, we are also increasing
-    // the value of accumulator in each partition
-    val sleepyRdd = input.mapPartitions { x =>
-      Thread.sleep(500)
-      accum.add(1)
-      x
-    }
-
-    // Listen for the job
-    val sem = new Semaphore(0)
-    val taskEndEvents = ArrayBuffer.empty[SparkListenerTaskEnd]
-    sc.addSparkListener(new SparkListener {
-      override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
-       sem.release()
-      }
-
-      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
-        taskEndEvents.append(taskEnd)
-      }
-    })
-
-    // Cache the RDD lazily
-    sleepyRdd.persist()
-
-    // Start the computation of RDD - this step will also cache the RDD
-    val asyncCount = sleepyRdd.countAsync()
-
-    // Wait for the job to have started
-    sem.acquire(1)
-
-    // Give Spark a tiny bit to start the tasks after the listener says hello
-    Thread.sleep(100)
-    // Decommission one of the executor
-    val sched = sc.schedulerBackend.asInstanceOf[StandaloneSchedulerBackend]
-    val execs = sched.getExecutorIds()
-    assert(execs.size == 2, s"Expected 2 executors but found ${execs.size}")
-    val execToDecommission = execs.head
-    sched.decommissionExecutor(execToDecommission)
-
-    // Wait for job to finish
-    val asyncCountResult = ThreadUtils.awaitResult(asyncCount, 6.seconds)
-    assert(asyncCountResult === 10)
-    // All 10 tasks finished, so accum should have been increased 10 times
-    assert(accum.value === 10)
-
-    // All tasks should be successful, nothing should have failed
-    sc.listenerBus.waitUntilEmpty()
-    assert(taskEndEvents.size === 10) // 10 mappers
-    assert(taskEndEvents.map(_.reason).toSet === Set(Success))
-
-    // Since the RDD is cached, so further usage of same RDD should use the
-    // cached data. Original RDD partitions should not be recomputed i.e. accum
-    // should have same value like before
-    assert(sleepyRdd.count() === 10)
-    assert(accum.value === 10)
-  }
-}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionUnitSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionUnitSuite.scala
new file mode 100644
index 0000000..5ff1ff0
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionUnitSuite.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import scala.concurrent.duration._
+
+import org.mockito.{ArgumentMatchers => mc}
+import org.mockito.Mockito.{mock, times, verify, when}
+import org.scalatest._
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark._
+import org.apache.spark.internal.config
+import org.apache.spark.network.BlockTransferService
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.shuffle.{MigratableResolver, ShuffleBlockInfo}
+import org.apache.spark.storage.BlockManagerMessages.ReplicateBlock
+
+/**
+ * Unit tests for [[BlockManagerDecommissioner]] that drive it against mocked block managers,
+ * shuffle resolvers and transfer services rather than a running cluster.
+ */
+class BlockManagerDecommissionUnitSuite extends SparkFunSuite with Matchers {
+
+  // Port advertised by every mocked peer / replica BlockManagerId in this suite.
+  private val bmPort = 12345
+
+  private val sparkConf = new SparkConf(false)
+    .set(config.STORAGE_DECOMMISSION_SHUFFLE_BLOCKS_ENABLED, true)
+    .set(config.STORAGE_DECOMMISSION_RDD_BLOCKS_ENABLED, true)
+
+  /**
+   * Stubs `mockMigratableShuffleResolver` so that `getStoredShuffles()` reports one
+   * [[ShuffleBlockInfo]] per distinct (shuffleId, mapId) in `ids`, and so that
+   * `getMigrationBlocks` for each of those returns one index block and one data block
+   * backed by mock buffers.
+   *
+   * Each `getMigrationBlocks` stub matches only its own [[ShuffleBlockInfo]]. Stubbing with
+   * `any()` inside the loop would let each iteration overwrite the previous one, so every
+   * shuffle would be handed the blocks of whichever triple was registered last.
+   */
+  private def registerShuffleBlocks(
+      mockMigratableShuffleResolver: MigratableResolver,
+      ids: Set[(Int, Long, Int)]): Unit = {
+
+    when(mockMigratableShuffleResolver.getStoredShuffles())
+      .thenReturn(ids.map(triple => ShuffleBlockInfo(triple._1, triple._2)).toSeq)
+
+    ids.foreach { case (shuffleId: Int, mapId: Long, reduceId: Int) =>
+      when(mockMigratableShuffleResolver.getMigrationBlocks(
+          mc.eq(ShuffleBlockInfo(shuffleId, mapId))))
+        .thenReturn(List(
+          (ShuffleIndexBlockId(shuffleId, mapId, reduceId), mock(classOf[ManagedBuffer])),
+          (ShuffleDataBlockId(shuffleId, mapId, reduceId), mock(classOf[ManagedBuffer]))))
+    }
+  }
+
+  test("test shuffle and cached rdd migration without any error") {
+    val blockTransferService = mock(classOf[BlockTransferService])
+    val bm = mock(classOf[BlockManager])
+
+    val storedBlockId1 = RDDBlockId(0, 0)
+    val storedBlock1 =
+      new ReplicateBlock(storedBlockId1, Seq(BlockManagerId("replicaHolder", "host1", bmPort)), 1)
+
+    val migratableShuffleBlockResolver = mock(classOf[MigratableResolver])
+    registerShuffleBlocks(migratableShuffleBlockResolver, Set((1, 1L, 1)))
+    // exec2 is the only peer, so both migrated shuffle blocks must be uploaded to it.
+    when(bm.getPeers(mc.any()))
+      .thenReturn(Seq(BlockManagerId("exec2", "host2", bmPort)))
+
+    when(bm.blockTransferService).thenReturn(blockTransferService)
+    when(bm.migratableResolver).thenReturn(migratableShuffleBlockResolver)
+    when(bm.getMigratableRDDBlocks())
+      .thenReturn(Seq(storedBlock1))
+
+    val bmDecomManager = new BlockManagerDecommissioner(sparkConf, bm)
+
+    try {
+      bmDecomManager.start()
+
+      eventually(timeout(5.seconds), interval(10.milliseconds)) {
+        assert(bmDecomManager.shufflesToMigrate.isEmpty)
+        // NOTE(review): Some(3) is assumed to be the default per-block replication
+        // failure limit used while decommissioning; confirm against the config.
+        verify(bm, times(1)).replicateBlock(
+          mc.eq(storedBlockId1), mc.any(), mc.any(), mc.eq(Some(3)))
+        // One upload each for the shuffle index block and the shuffle data block.
+        verify(blockTransferService, times(2))
+          .uploadBlockSync(mc.eq("host2"), mc.eq(bmPort), mc.eq("exec2"), mc.any(), mc.any(),
+            mc.eq(StorageLevel.DISK_ONLY), mc.isNull())
+      }
+    } finally {
+      bmDecomManager.stop()
+    }
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index 660bfcf..d18d84d 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -103,7 +103,7 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite
     val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]()
     master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager",
       new BlockManagerMasterEndpoint(rpcEnv, true, conf,
-        new LiveListenerBus(conf), None, blockManagerInfo)),
+        new LiveListenerBus(conf), None, blockManagerInfo, mapOutputTracker)),
       rpcEnv.setupEndpoint("blockmanagerHeartbeat",
       new BlockManagerMasterHeartbeatEndpoint(rpcEnv, true, blockManagerInfo)), conf, true)
     allStores.clear()
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index dc1c7cd..62bb4d9 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.storage
 
 import java.io.File
 import java.nio.ByteBuffer
+import java.nio.file.Files
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -50,10 +51,11 @@ import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, Transpo
 import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExecutorDiskUtils, ExternalBlockStoreClient}
 import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor}
 import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}
-import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerBlockUpdated}
+import org.apache.spark.scheduler.{LiveListenerBus, MapStatus, SparkListenerBlockUpdated}
 import org.apache.spark.scheduler.cluster.{CoarseGrainedClusterMessages, CoarseGrainedSchedulerBackend}
 import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite}
 import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager}
+import org.apache.spark.shuffle.{ShuffleBlockResolver, ShuffleManager}
 import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.storage.BlockManagerMessages._
 import org.apache.spark.util._
@@ -61,7 +63,7 @@ import org.apache.spark.util.io.ChunkedByteBuffer
 
 class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach
   with PrivateMethodTester with LocalSparkContext with ResetSystemProperties
-  with EncryptionFunSuite with TimeLimits {
+  with EncryptionFunSuite with TimeLimits with BeforeAndAfterAll {
 
   import BlockManagerSuite._
 
@@ -70,6 +72,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
 
   var conf: SparkConf = null
   val allStores = ArrayBuffer[BlockManager]()
+  val sortShuffleManagers = ArrayBuffer[SortShuffleManager]()
   var rpcEnv: RpcEnv = null
   var master: BlockManagerMaster = null
   var liveListenerBus: LiveListenerBus = null
@@ -97,12 +100,19 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
       .set(Network.RPC_ASK_TIMEOUT, "5s")
   }
 
+  private def makeSortShuffleManager(): SortShuffleManager = {
+    val manager = new SortShuffleManager(new SparkConf(false))
+    sortShuffleManagers.append(manager) // tracked so the suite's teardown can stop it
+    manager
+  }
+
   private def makeBlockManager(
       maxMem: Long,
       name: String = SparkContext.DRIVER_IDENTIFIER,
       master: BlockManagerMaster = this.master,
       transferService: Option[BlockTransferService] = Option.empty,
-      testConf: Option[SparkConf] = None): BlockManager = {
+      testConf: Option[SparkConf] = None,
+      shuffleManager: ShuffleManager = shuffleManager): BlockManager = {
     val bmConf = testConf.map(_.setAll(conf.getAll)).getOrElse(conf)
     bmConf.set(TEST_MEMORY, maxMem)
     bmConf.set(MEMORY_OFFHEAP_SIZE, maxMem)
@@ -153,7 +163,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     liveListenerBus = spy(new LiveListenerBus(conf))
     master = spy(new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager",
       new BlockManagerMasterEndpoint(rpcEnv, true, conf,
-        liveListenerBus, None, blockManagerInfo)),
+        liveListenerBus, None, blockManagerInfo, mapOutputTracker)),
       rpcEnv.setupEndpoint("blockmanagerHeartbeat",
       new BlockManagerMasterHeartbeatEndpoint(rpcEnv, true, blockManagerInfo)), conf, true))
 
@@ -166,6 +176,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
       conf = null
       allStores.foreach(_.stop())
       allStores.clear()
+      sortShuffleManagers.foreach(_.stop())
+      sortShuffleManagers.clear()
       rpcEnv.shutdown()
       rpcEnv.awaitTermination()
       rpcEnv = null
@@ -176,6 +188,17 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     }
   }
 
+  override def afterAll(): Unit = {
+    try {
+      // Stop the components reused across tests; Option(...) skips any that are still null.
+      Option(bcastManager).foreach(_.stop())
+      Option(mapOutputTracker).foreach(_.stop())
+      Option(shuffleManager).foreach(_.stop())
+    } finally {
+      super.afterAll()
+    }
+  }
+
   private def stopBlockManager(blockManager: BlockManager): Unit = {
     allStores -= blockManager
     blockManager.stop()
@@ -1815,6 +1838,19 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     verify(liveListenerBus, never()).post(SparkListenerBlockUpdated(BlockUpdatedInfo(updateInfo)))
   }
 
+  test("we reject putting blocks when we have the wrong shuffle resolver") {
+    // A plain ShuffleBlockResolver mock (presumably not migratable) must make the put fail.
+    val unsupportedResolver = mock(classOf[ShuffleBlockResolver])
+    val unsupportedManager = mock(classOf[ShuffleManager])
+    when(unsupportedManager.shuffleBlockResolver).thenReturn(unsupportedResolver)
+    val bm = makeBlockManager(100, "exec1", shuffleManager = unsupportedManager)
+    val thrown = intercept[SparkException] {
+      bm.putBlockDataAsStream(
+        ShuffleDataBlockId(0, 0, 0), StorageLevel.DISK_ONLY, ClassTag("message".getClass))
+    }
+    assert(thrown.getMessage.contains("unsupported shuffle resolver"))
+  }
+
   test("test decommission block manager should not be part of peers") {
     val exec1 = "exec1"
     val exec2 = "exec2"
@@ -1846,7 +1882,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     assert(master.getLocations(blockId).size === 2)
     assert(master.getLocations(blockId).contains(store1.blockManagerId))
 
-    store1.decommissionRddCacheBlocks()
+    val decomManager = new BlockManagerDecommissioner(conf, store1)
+    decomManager.decommissionRddCacheBlocks()
     assert(master.getLocations(blockId).size === 2)
     assert(master.getLocations(blockId).toSet === Set(store2.blockManagerId,
       store3.blockManagerId))
@@ -1866,13 +1903,57 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     assert(master.getLocations(blockIdLarge) === Seq(store1.blockManagerId))
     assert(master.getLocations(blockIdSmall) === Seq(store1.blockManagerId))
 
-    store1.decommissionRddCacheBlocks()
+    val decomManager = new BlockManagerDecommissioner(conf, store1)
+    decomManager.decommissionRddCacheBlocks()
     // Smaller block offloaded to store2
     assert(master.getLocations(blockIdSmall) === Seq(store2.blockManagerId))
     // Larger block still present in store1 as it can't be offloaded
     assert(master.getLocations(blockIdLarge) === Seq(store1.blockManagerId))
   }
 
+  test("test migration of shuffle blocks during decommissioning") {
+    val shuffleManager1 = makeSortShuffleManager()
+    val bm1 = makeBlockManager(3500, "exec1", shuffleManager = shuffleManager1)
+    shuffleManager1.shuffleBlockResolver._blockManager = bm1
+
+    val shuffleManager2 = makeSortShuffleManager()
+    val bm2 = makeBlockManager(3500, "exec2", shuffleManager = shuffleManager2)
+    shuffleManager2.shuffleBlockResolver._blockManager = bm2
+    // Write a small shuffle data file and index file for map 0 of shuffle 0 on bm1 only.
+    val blockSize = 5
+    val shuffleDataBlockContent = Array[Byte](0, 1, 2, 3, 4)
+    val shuffleData = ShuffleDataBlockId(0, 0, 0)
+    Files.write(bm1.diskBlockManager.getFile(shuffleData).toPath(), shuffleDataBlockContent)
+    val shuffleIndexBlockContent = Array[Byte](5, 6, 7, 8, 9)
+    val shuffleIndex = ShuffleIndexBlockId(0, 0, 0)
+    Files.write(bm1.diskBlockManager.getFile(shuffleIndex).toPath(), shuffleIndexBlockContent)
+    val previousEnv = SparkEnv.get // restored in finally so the mocked env cannot leak
+    mapOutputTracker.registerShuffle(0, 1)
+    val decomManager = new BlockManagerDecommissioner(conf, bm1)
+    try {
+      mapOutputTracker.registerMapOutput(0, 0, MapStatus(bm1.blockManagerId, Array(blockSize), 0))
+      assert(mapOutputTracker.shuffleStatuses(0).mapStatuses(0).location === bm1.blockManagerId)
+
+      val env = mock(classOf[SparkEnv])
+      when(env.conf).thenReturn(conf)
+      SparkEnv.set(env)
+
+      decomManager.refreshOffloadingShuffleBlocks()
+      // Migration completes off this thread, so poll until the map output points at bm2.
+      eventually(timeout(1.second), interval(10.milliseconds)) {
+        assert(mapOutputTracker.shuffleStatuses(0).mapStatuses(0).location === bm2.blockManagerId)
+      }
+      assert(Files.readAllBytes(bm2.diskBlockManager.getFile(shuffleData).toPath())
+        === shuffleDataBlockContent)
+      assert(Files.readAllBytes(bm2.diskBlockManager.getFile(shuffleIndex).toPath())
+        === shuffleIndexBlockContent)
+    } finally {
+      mapOutputTracker.unregisterShuffle(0)
+      decomManager.stopOffloadingShuffleBlocks() // Avoid thread leak
+      SparkEnv.set(previousEnv)
+    }
+  }
+
   class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService {
     var numCalls = 0
     var tempFileManager: DownloadFileManager = null
diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala
index becf941..fd67a03 100644
--- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala
+++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.spark.deploy.k8s.integrationtest
 
+import org.apache.spark.internal.config
 import org.apache.spark.internal.config.Worker
 
 private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite =>
@@ -28,18 +29,28 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite =>
       .set(Worker.WORKER_DECOMMISSION_ENABLED.key, "true")
       .set("spark.kubernetes.pyspark.pythonVersion", "3")
       .set("spark.kubernetes.container.image", pyImage)
+      .set(config.STORAGE_DECOMMISSION_ENABLED.key, "true")
+      .set(config.STORAGE_DECOMMISSION_SHUFFLE_BLOCKS_ENABLED.key, "true")
+      .set(config.STORAGE_DECOMMISSION_RDD_BLOCKS_ENABLED.key, "true")
+      // Ensure we have somewhere to migrate our data to
+      .set("spark.executor.instances", "3")
+      // The default of 30 seconds is fine, but for testing we just want to get this done fast.
+      .set("spark.storage.decommission.replicationReattemptInterval", "1")
 
     runSparkApplicationAndVerifyCompletion(
       appResource = PYSPARK_DECOMISSIONING,
       mainClass = "",
       expectedLogOnCompletion = Seq(
         "Finished waiting, stopping Spark",
-        "decommissioning executor"),
+        "decommissioning executor",
+        "Final accumulator value is: 100"),
       appArgs = Array.empty[String],
       driverPodChecker = doBasicDriverPyPodCheck,
       executorPodChecker = doBasicExecutorPyPodCheck,
       appLocator = appLocator,
       isJVM = false,
+      pyFiles = None,
+      executorPatience = None,
       decommissioningTest = true)
   }
 }
diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala
index 65a2f1f..ebf71e8 100644
--- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala
+++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala
@@ -42,7 +42,8 @@ import org.apache.spark.internal.config._
 class KubernetesSuite extends SparkFunSuite
   with BeforeAndAfterAll with BeforeAndAfter with BasicTestsSuite with SecretsTestsSuite
   with PythonTestsSuite with ClientModeTestsSuite with PodTemplateSuite with PVTestsSuite
-  with DepsTestsSuite with DecommissionSuite with RTestsSuite with Logging with Eventually
+  // TODO(SPARK-32354): Fix and re-enable the R tests.
+  with DepsTestsSuite with DecommissionSuite /* with RTestsSuite */ with Logging with Eventually
   with Matchers {
 
 
@@ -325,21 +326,36 @@ class KubernetesSuite extends SparkFunSuite
                   val result = checkPodReady(namespace, name)
                   result shouldBe (true)
                 }
-                // Look for the string that indicates we're good to clean up
-                // on the driver
+                // Look for the string that indicates we're good to trigger decom on the driver
                 logDebug("Waiting for first collect...")
                 Eventually.eventually(TIMEOUT, INTERVAL) {
                   assert(kubernetesTestComponents.kubernetesClient
                     .pods()
                     .withName(driverPodName)
                     .getLog
-                    .contains("Waiting to give nodes time to finish."),
+                    .contains("Waiting to give nodes time to finish migration, decom exec 1."),
                     "Decommission test did not complete first collect.")
                 }
                 // Delete the pod to simulate cluster scale down/migration.
-                val pod = kubernetesTestComponents.kubernetesClient.pods().withName(name)
+                // This will allow the pod to remain up for the grace period
+                val pod = kubernetesTestComponents.kubernetesClient.pods()
+                  .withName(name)
                 pod.delete()
                 logDebug(s"Triggered pod decom/delete: $name deleted")
+                // Look for the string that indicates we should force kill the first
+                // Executor. This simulates the pod being fully lost.
+                logDebug("Waiting for second collect...")
+                Eventually.eventually(TIMEOUT, INTERVAL) {
+                  assert(kubernetesTestComponents.kubernetesClient
+                    .pods()
+                    .withName(driverPodName)
+                    .getLog
+                    .contains("Waiting some more, please kill exec 1."),
+                    "Decommission test did not complete second collect.")
+                }
+                logDebug("Force deleting")
+                val podNoGrace = pod.withGracePeriod(0)
+                podNoGrace.delete()
               }
             case Action.DELETED | Action.ERROR =>
               execPods.remove(name)
@@ -365,9 +381,10 @@ class KubernetesSuite extends SparkFunSuite
       .get(0)
 
     driverPodChecker(driverPod)
-    // If we're testing decommissioning we delete all the executors, but we should have
-    // an executor at some point.
-    Eventually.eventually(patienceTimeout, patienceInterval) {
+
+    // If we're testing decommissioning we delete an executor, but we should have an executor
+    // at some point.
+    Eventually.eventually(TIMEOUT, patienceInterval) {
       execPods.values.nonEmpty should be (true)
     }
     execWatcher.close()
@@ -482,6 +499,6 @@ private[spark] object KubernetesSuite {
   val SPARK_DFS_READ_WRITE_TEST = "org.apache.spark.examples.DFSReadWriteTest"
   val SPARK_REMOTE_MAIN_CLASS: String = "org.apache.spark.examples.SparkRemoteFileTest"
   val SPARK_DRIVER_MAIN_CLASS: String = "org.apache.spark.examples.DriverSubmissionTest"
-  val TIMEOUT = PatienceConfiguration.Timeout(Span(2, Minutes))
+  val TIMEOUT = PatienceConfiguration.Timeout(Span(3, Minutes))
   val INTERVAL = PatienceConfiguration.Interval(Span(1, Seconds))
 }
diff --git a/resource-managers/kubernetes/integration-tests/tests/decommissioning.py b/resource-managers/kubernetes/integration-tests/tests/decommissioning.py
index f68f24d..d34e616 100644
--- a/resource-managers/kubernetes/integration-tests/tests/decommissioning.py
+++ b/resource-managers/kubernetes/integration-tests/tests/decommissioning.py
@@ -31,14 +31,29 @@ if __name__ == "__main__":
         .appName("PyMemoryTest") \
         .getOrCreate()
     sc = spark._sc
-    rdd = sc.parallelize(range(10))
-    rdd.collect()
-    print("Waiting to give nodes time to finish.")
-    time.sleep(5)
+    acc = sc.accumulator(0)
+
+    def addToAcc(x):
+        acc.add(1)
+        return x
+
+    initialRdd = sc.parallelize(range(100), 5)
+    accRdd = initialRdd.map(addToAcc)
+    # Trigger a shuffle so there are shuffle blocks to migrate
+    rdd = accRdd.map(lambda x: (x, x)).groupByKey()
     rdd.collect()
-    print("Waiting some more....")
-    time.sleep(10)
+    print("1st accumulator value is: " + str(acc.value))
+    print("Waiting to give nodes time to finish migration, decom exec 1.")
+    print("...")
+    time.sleep(30)
+    rdd.count()
+    print("Waiting some more, please kill exec 1.")
+    print("...")
+    time.sleep(30)
+    print("Executor node should be deleted now")
+    rdd.count()
     rdd.collect()
+    print("Final accumulator value is: " + str(acc.value))
     print("Finished waiting, stopping Spark.")
     spark.stop()
     print("Done, exiting Python")
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index 0976494..558e2c9 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -91,7 +91,7 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean)
     val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]()
     blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager",
       new BlockManagerMasterEndpoint(rpcEnv, true, conf,
-        new LiveListenerBus(conf), None, blockManagerInfo)),
+        new LiveListenerBus(conf), None, blockManagerInfo, mapOutputTracker)),
       rpcEnv.setupEndpoint("blockmanagerHeartbeat",
       new BlockManagerMasterHeartbeatEndpoint(rpcEnv, true, blockManagerInfo)), conf, true)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org