Posted to commits@spark.apache.org by mr...@apache.org on 2021/08/02 04:17:24 UTC

[spark] branch master updated: [SPARK-32923][CORE][SHUFFLE] Handle indeterminate stage retries for push-based shuffle

This is an automated email from the ASF dual-hosted git repository.

mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c039d99  [SPARK-32923][CORE][SHUFFLE] Handle indeterminate stage retries for push-based shuffle
c039d99 is described below

commit c039d998128dd0dab27f43e7de083a71b9d1cfcf
Author: Venkata krishnan Sowrirajan <vs...@linkedin.com>
AuthorDate: Sun Aug 1 23:16:33 2021 -0500

    [SPARK-32923][CORE][SHUFFLE] Handle indeterminate stage retries for push-based shuffle
    
    ### What changes were proposed in this pull request?
    [[SPARK-23243](https://issues.apache.org/jira/browse/SPARK-23243)] and [[SPARK-25341](https://issues.apache.org/jira/browse/SPARK-25341)] addressed stage retries for indeterminate stages involving operations like repartition. This PR addresses the same issues in the context of push-based shuffle. Currently, there is no way to distinguish the current execution of a stage for a shuffle ID. Therefore, the changes explained below are necessary.
    
    Core changes are summarized as follows:
    
    1. Introduce a new variable `shuffleMergeId` in `ShuffleDependency`: a monotonically increasing value tracking the temporal ordering of <stage-id, stage-attempt-id> executions for a shuffle ID.
    2. Correspondingly, change the push-based shuffle protocol layer in `MergedShuffleFileManager` and `BlockStoreClient` to pass the `shuffleMergeId`, in order to keep track of the shuffle output in separate files on the shuffle service side.
    3. `DAGScheduler` increments the `shuffleMergeId` tracked in `ShuffleDependency` in the case of an indeterminate stage execution.
    4. A deterministic stage will have `shuffleMergeId` set to 0, as no special handling is needed in that case; an indeterminate stage will have `shuffleMergeId` starting from 1 (see the sketch after this list).
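    
    A minimal sketch of the `shuffleMergeId` convention above (the helper and its names are illustrative, not the actual `DAGScheduler` internals): determinate shuffles stay at 0, while each retry of an indeterminate stage advances the id, so the shuffle service can tell current pushes apart from stale ones.
    
    ```java
    class ShuffleMergeIdSketch {
      // Hypothetical helper: determinate shuffles keep shuffleMergeId 0;
      // each indeterminate stage (re)execution advances it, starting from 1.
      static int nextShuffleMergeId(boolean indeterminate, int current) {
        return indeterminate ? current + 1 : 0;
      }
    
      public static void main(String[] args) {
        System.out.println(nextShuffleMergeId(false, 0)); // determinate: always 0
        System.out.println(nextShuffleMergeId(true, 0));  // first indeterminate attempt: 1
        System.out.println(nextShuffleMergeId(true, 1));  // after a stage retry: 2
      }
    }
    ```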
    
    ### Why are the changes needed?
    
    New protocol changes are needed for the reasons explained above.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    Added new unit tests in `RemoteBlockPushResolverSuite`, `DAGSchedulerSuite`, `BlockIdSuite`, and `ErrorHandlerSuite`.
    
    Closes #33034 from venkata91/SPARK-32923.
    
    Authored-by: Venkata krishnan Sowrirajan <vs...@linkedin.com>
    Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
 .../spark/network/client/TransportClient.java      |   6 +-
 .../network/protocol/MergedBlockMetaRequest.java   |  23 +-
 .../network/TransportRequestHandlerSuite.java      |   4 +-
 .../spark/network/shuffle/BlockStoreClient.java    |   6 +
 .../apache/spark/network/shuffle/ErrorHandler.java |  53 ++-
 .../network/shuffle/ExternalBlockHandler.java      |  90 +++-
 .../network/shuffle/ExternalBlockStoreClient.java  |  22 +-
 .../network/shuffle/MergedBlocksMetaListener.java  |   8 +-
 .../network/shuffle/MergedShuffleFileManager.java  |  17 +-
 .../network/shuffle/OneForOneBlockFetcher.java     | 126 ++---
 .../network/shuffle/OneForOneBlockPusher.java      |   5 +-
 .../network/shuffle/RemoteBlockPushResolver.java   | 468 ++++++++++++------
 .../shuffle/protocol/FetchShuffleBlockChunks.java  |  19 +-
 .../shuffle/protocol/FinalizeShuffleMerge.java     |  17 +-
 .../network/shuffle/protocol/MergeStatuses.java    |  18 +-
 .../network/shuffle/protocol/PushBlockStream.java  |  15 +-
 .../spark/network/shuffle/ErrorHandlerSuite.java   |  34 +-
 .../network/shuffle/ExternalBlockHandlerSuite.java |  29 +-
 .../shuffle/OneForOneBlockFetcherSuite.java        |  42 +-
 .../network/shuffle/OneForOneBlockPusherSuite.java |  69 +--
 .../shuffle/RemoteBlockPushResolverSuite.java      | 528 ++++++++++++++-------
 .../protocol/FetchShuffleBlockChunksSuite.java     |   4 +-
 .../main/scala/org/apache/spark/Dependency.scala   |  32 +-
 .../scala/org/apache/spark/MapOutputTracker.scala  |   9 +-
 .../org/apache/spark/scheduler/DAGScheduler.scala  |   5 +-
 .../org/apache/spark/scheduler/MergeStatus.scala   |  19 +-
 .../spark/shuffle/IndexShuffleBlockResolver.scala  |  25 +-
 .../apache/spark/shuffle/ShuffleBlockPusher.scala  |  22 +-
 .../spark/shuffle/ShuffleBlockResolver.scala       |  10 +-
 .../scala/org/apache/spark/storage/BlockId.scala   |  77 ++-
 .../org/apache/spark/storage/BlockManager.scala    |   4 +-
 .../spark/storage/PushBasedFetchHelper.scala       |  41 +-
 .../storage/ShuffleBlockFetcherIterator.scala      |  35 +-
 .../org/apache/spark/MapOutputTrackerSuite.scala   |  22 +-
 .../apache/spark/scheduler/DAGSchedulerSuite.scala |  78 ++-
 .../spark/shuffle/ShuffleBlockPusherSuite.scala    |  15 +-
 .../sort/IndexShuffleBlockResolverSuite.scala      |  18 +-
 .../org/apache/spark/storage/BlockIdSuite.scala    |  47 +-
 .../storage/ShuffleBlockFetcherIteratorSuite.scala | 161 ++++---
 39 files changed, 1501 insertions(+), 722 deletions(-)

diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
index a50c04c..dd2fdb0 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -206,12 +206,15 @@ public class TransportClient implements Closeable {
    *
    * @param appId applicationId.
    * @param shuffleId shuffle id.
+   * @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process
+   *                       of shuffle by an indeterminate stage attempt.
    * @param reduceId reduce id.
    * @param callback callback the handle the reply.
    */
   public void sendMergedBlockMetaReq(
       String appId,
       int shuffleId,
+      int shuffleMergeId,
       int reduceId,
       MergedBlockMetaResponseCallback callback) {
     long requestId = requestId();
@@ -222,7 +225,8 @@ public class TransportClient implements Closeable {
     handler.addRpcRequest(requestId, callback);
     RpcChannelListener listener = new RpcChannelListener(requestId, callback);
     channel.writeAndFlush(
-      new MergedBlockMetaRequest(requestId, appId, shuffleId, reduceId)).addListener(listener);
+      new MergedBlockMetaRequest(requestId, appId, shuffleId, shuffleMergeId,
+        reduceId)).addListener(listener);
   }
 
   /**
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java
index cf7c22d..c85d104 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java
@@ -32,13 +32,20 @@ public class MergedBlockMetaRequest extends AbstractMessage implements RequestMe
   public final long requestId;
   public final String appId;
   public final int shuffleId;
+  public final int shuffleMergeId;
   public final int reduceId;
 
-  public MergedBlockMetaRequest(long requestId, String appId, int shuffleId, int reduceId) {
+  public MergedBlockMetaRequest(
+      long requestId,
+      String appId,
+      int shuffleId,
+      int shuffleMergeId,
+      int reduceId) {
     super(null, false);
     this.requestId = requestId;
     this.appId = appId;
     this.shuffleId = shuffleId;
+    this.shuffleMergeId = shuffleMergeId;
     this.reduceId = reduceId;
   }
 
@@ -49,7 +56,7 @@ public class MergedBlockMetaRequest extends AbstractMessage implements RequestMe
 
   @Override
   public int encodedLength() {
-    return 8 + Encoders.Strings.encodedLength(appId) + 4 + 4;
+    return 8 + Encoders.Strings.encodedLength(appId) + 4 + 4 + 4;
   }
 
   @Override
@@ -57,6 +64,7 @@ public class MergedBlockMetaRequest extends AbstractMessage implements RequestMe
     buf.writeLong(requestId);
     Encoders.Strings.encode(buf, appId);
     buf.writeInt(shuffleId);
+    buf.writeInt(shuffleMergeId);
     buf.writeInt(reduceId);
   }
 
@@ -64,21 +72,23 @@ public class MergedBlockMetaRequest extends AbstractMessage implements RequestMe
     long requestId = buf.readLong();
     String appId = Encoders.Strings.decode(buf);
     int shuffleId = buf.readInt();
+    int shuffleMergeId = buf.readInt();
     int reduceId = buf.readInt();
-    return new MergedBlockMetaRequest(requestId, appId, shuffleId, reduceId);
+    return new MergedBlockMetaRequest(requestId, appId, shuffleId, shuffleMergeId, reduceId);
   }
 
   @Override
   public int hashCode() {
-    return Objects.hashCode(requestId, appId, shuffleId, reduceId);
+    return Objects.hashCode(requestId, appId, shuffleId, shuffleMergeId, reduceId);
   }
 
   @Override
   public boolean equals(Object other) {
     if (other instanceof MergedBlockMetaRequest) {
       MergedBlockMetaRequest o = (MergedBlockMetaRequest) other;
-      return requestId == o.requestId && shuffleId == o.shuffleId && reduceId == o.reduceId
-        && Objects.equal(appId, o.appId);
+      return requestId == o.requestId && shuffleId == o.shuffleId &&
+        shuffleMergeId == o.shuffleMergeId && reduceId == o.reduceId &&
+        Objects.equal(appId, o.appId);
     }
     return false;
   }
@@ -89,6 +99,7 @@ public class MergedBlockMetaRequest extends AbstractMessage implements RequestMe
       .append("requestId", requestId)
       .append("appId", appId)
       .append("shuffleId", shuffleId)
+      .append("shuffleMergeId", shuffleMergeId)
       .append("reduceId", reduceId)
       .toString();
   }
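
A minimal sketch of the wire layout this change produces, using plain Netty buffers for illustration (Spark's `Encoders.Strings` writes an int length prefix followed by UTF-8 bytes, mirrored here): `requestId` (8 bytes), `appId`, then `shuffleId`, the new `shuffleMergeId`, and `reduceId` at 4 bytes each.

```java
import java.nio.charset.StandardCharsets;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

public class MergedBlockMetaWireSketch {
  public static void main(String[] args) {
    long requestId = 19L;
    String appId = "app1";
    int shuffleId = 0, shuffleMergeId = 2, reduceId = 5;

    byte[] appIdBytes = appId.getBytes(StandardCharsets.UTF_8);
    ByteBuf buf = Unpooled.buffer();
    buf.writeLong(requestId);
    buf.writeInt(appIdBytes.length);  // length prefix, as Encoders.Strings.encode does
    buf.writeBytes(appIdBytes);
    buf.writeInt(shuffleId);
    buf.writeInt(shuffleMergeId);     // the field added by this commit
    buf.writeInt(reduceId);

    // Matches encodedLength(): 8 + (4 + appId bytes) + 4 + 4 + 4.
    assert buf.readableBytes() == 8 + 4 + appIdBytes.length + 4 + 4 + 4;

    // Decode in the same order the server-side decoder reads.
    assert buf.readLong() == requestId;
    byte[] decoded = new byte[buf.readInt()];
    buf.readBytes(decoded);
    assert appId.equals(new String(decoded, StandardCharsets.UTF_8));
    assert buf.readInt() == shuffleId;
    assert buf.readInt() == shuffleMergeId;
    assert buf.readInt() == reduceId;
  }
}
```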
diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java
index b3befb8..70c7a16 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java
@@ -152,14 +152,14 @@ public class TransportRequestHandlerSuite {
     TransportClient reverseClient = mock(TransportClient.class);
     TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient,
       rpcHandler, 2L, null);
-    MergedBlockMetaRequest validMetaReq = new MergedBlockMetaRequest(19, "app1", 0, 0);
+    MergedBlockMetaRequest validMetaReq = new MergedBlockMetaRequest(19, "app1", 0, 0, 0);
     requestHandler.handle(validMetaReq);
     assertEquals(1, responseAndPromisePairs.size());
     assertTrue(responseAndPromisePairs.get(0).getLeft() instanceof MergedBlockMetaSuccess);
     assertEquals(2,
       ((MergedBlockMetaSuccess) (responseAndPromisePairs.get(0).getLeft())).getNumChunks());
 
-    MergedBlockMetaRequest invalidMetaReq = new MergedBlockMetaRequest(21, "app1", -1, 1);
+    MergedBlockMetaRequest invalidMetaReq = new MergedBlockMetaRequest(21, "app1", -1, 0, 1);
     requestHandler.handle(invalidMetaReq);
     assertEquals(2, responseAndPromisePairs.size());
     assertTrue(responseAndPromisePairs.get(1).getLeft() instanceof RpcFailure);
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java
index b685213..8298846 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java
@@ -167,6 +167,8 @@ public abstract class BlockStoreClient implements Closeable {
    * @param host host of shuffle server
    * @param port port of shuffle server.
    * @param shuffleId shuffle ID of the shuffle to be finalized
+   * @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process
+   *                       of shuffle by an indeterminate stage attempt.
    * @param listener the listener to receive MergeStatuses
    *
    * @since 3.1.0
@@ -175,6 +177,7 @@ public abstract class BlockStoreClient implements Closeable {
       String host,
       int port,
       int shuffleId,
+      int shuffleMergeId,
       MergeFinalizerListener listener) {
     throw new UnsupportedOperationException();
   }
@@ -185,6 +188,8 @@ public abstract class BlockStoreClient implements Closeable {
    * @param host the host of the remote node.
    * @param port the port of the remote node.
    * @param shuffleId shuffle id.
+   * @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process
+   *                       of shuffle by an indeterminate stage attempt.
    * @param reduceId reduce id.
    * @param listener the listener to receive chunk counts.
    *
@@ -194,6 +199,7 @@ public abstract class BlockStoreClient implements Closeable {
       String host,
       int port,
       int shuffleId,
+      int shuffleMergeId,
       int reduceId,
       MergedBlocksMetaListener listener) {
     throw new UnsupportedOperationException();
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java
index a758875..0149ad7 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java
@@ -55,12 +55,14 @@ public interface ErrorHandler {
   class BlockPushErrorHandler implements ErrorHandler {
     /**
      * String constant used for generating exception messages indicating a block to be merged
-     * arrives too late on the server side, and also for later checking such exceptions on the
-     * client side. When we get a block push failure because of the block arrives too late, we
-     * will not retry pushing the block nor log the exception on the client side.
+     * arrives too late or stale block push in the case of indeterminate stage retries on the
+     * server side, and also for later checking such exceptions on the client side. When we get
+     * a block push failure because of the block push being stale or arrives too late, we will
+     * not retry pushing the block nor log the exception on the client side.
      */
-    public static final String TOO_LATE_MESSAGE_SUFFIX =
-      "received after merged shuffle is finalized";
+    public static final String TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX =
+      "received after merged shuffle is finalized or stale block push as shuffle blocks of a"
+        + " higher shuffleMergeId for the shuffle is being pushed";
 
     /**
      * String constant used for generating exception messages indicating the server couldn't
@@ -81,25 +83,54 @@ public interface ErrorHandler {
     public static final String IOEXCEPTIONS_EXCEEDED_THRESHOLD_PREFIX =
       "IOExceptions exceeded the threshold";
 
+    /**
+     * String constant used for generating exception messages indicating the server rejecting a
+     * shuffle finalize request since shuffle blocks of a higher shuffleMergeId for a shuffle is
+     * already being pushed. This typically happens in the case of indeterminate stage retries
+     * where if a stage attempt fails then the entirety of the shuffle output needs to be rolled
+     * back. For more details refer SPARK-23243, SPARK-25341 and SPARK-32923.
+     */
+    public static final String STALE_SHUFFLE_FINALIZE_SUFFIX =
+      "stale shuffle finalize request as shuffle blocks of a higher shuffleMergeId for the"
+        + " shuffle is already being pushed";
+
     @Override
     public boolean shouldRetryError(Throwable t) {
       // If it is a connection time-out or a connection closed exception, no need to retry.
       // If it is a FileNotFoundException originating from the client while pushing the shuffle
-      // blocks to the server, even then there is no need to retry. We will still log this exception
-      // once which helps with debugging.
+      // blocks to the server, even then there is no need to retry. We will still log this
+      // exception once which helps with debugging.
       if (t.getCause() != null && (t.getCause() instanceof ConnectException ||
           t.getCause() instanceof FileNotFoundException)) {
         return false;
       }
-      // If the block is too late, there is no need to retry it
-      return !Throwables.getStackTraceAsString(t).contains(TOO_LATE_MESSAGE_SUFFIX);
+
+      String errorStackTrace = Throwables.getStackTraceAsString(t);
+      // If the block is too late or stale block push, there is no need to retry it
+      return !errorStackTrace.contains(TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX);
     }
 
     @Override
     public boolean shouldLogError(Throwable t) {
       String errorStackTrace = Throwables.getStackTraceAsString(t);
-      return !errorStackTrace.contains(BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX) &&
-        !errorStackTrace.contains(TOO_LATE_MESSAGE_SUFFIX);
+      return !(errorStackTrace.contains(BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX) ||
+        errorStackTrace.contains(TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX));
+    }
+  }
+
+  class BlockFetchErrorHandler implements ErrorHandler {
+    public static final String STALE_SHUFFLE_BLOCK_FETCH =
+      "stale shuffle block fetch request as shuffle blocks of a higher shuffleMergeId for the"
+        + " shuffle is available";
+
+    @Override
+    public boolean shouldRetryError(Throwable t) {
+      return !Throwables.getStackTraceAsString(t).contains(STALE_SHUFFLE_BLOCK_FETCH);
+    }
+
+    @Override
+    public boolean shouldLogError(Throwable t) {
+      return !Throwables.getStackTraceAsString(t).contains(STALE_SHUFFLE_BLOCK_FETCH);
     }
   }
 }
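
A minimal sketch of the retry/log decision the handlers above encode: an exception whose stack trace carries the too-late/stale suffix is neither retried nor logged, since pushing again would only be rejected for the same shuffleMergeId. The constant mirrors `TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX`; the standalone class is illustrative.

```java
import com.google.common.base.Throwables;

public class PushErrorDecisionSketch {
  // Mirrors BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX above.
  static final String TOO_LATE_OR_STALE =
    "received after merged shuffle is finalized or stale block push as shuffle blocks of a"
      + " higher shuffleMergeId for the shuffle is being pushed";

  // Stale or too-late pushes are terminal for this shuffleMergeId; anything
  // else (e.g. a transient connection error) is still worth retrying.
  static boolean shouldRetry(Throwable t) {
    return !Throwables.getStackTraceAsString(t).contains(TOO_LATE_OR_STALE);
  }

  public static void main(String[] args) {
    Throwable stale = new RuntimeException("block shufflePush_5_1_3_10 " + TOO_LATE_OR_STALE);
    Throwable flaky = new RuntimeException("connection reset by peer");
    System.out.println(shouldRetry(stale)); // false: drop silently, no retry, no log
    System.out.println(shouldRetry(flaky)); // true
  }
}
```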
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
index e0f2e95..cfabcd5 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
@@ -218,7 +218,8 @@ public class ExternalBlockHandler extends RpcHandler
         callback.onSuccess(statuses.toByteBuffer());
       } catch(IOException e) {
         throw new RuntimeException(String.format("Error while finalizing shuffle merge "
-          + "for application %s shuffle %d", msg.appId, msg.shuffleId), e);
+          + "for application %s shuffle %d with shuffleMergeId %d", msg.appId, msg.shuffleId,
+            msg.shuffleMergeId), e);
       } finally {
         responseDelayContext.stop();
       }
@@ -237,7 +238,7 @@ public class ExternalBlockHandler extends RpcHandler
       checkAuth(client, metaRequest.appId);
       MergedBlockMeta mergedMeta =
         mergeManager.getMergedBlockMeta(metaRequest.appId, metaRequest.shuffleId,
-          metaRequest.reduceId);
+          metaRequest.shuffleMergeId, metaRequest.reduceId);
       logger.debug(
         "Merged block chunks appId {} shuffleId {} reduceId {} num-chunks : {} ",
           metaRequest.appId, metaRequest.shuffleId, metaRequest.reduceId,
@@ -364,7 +365,6 @@ public class ExternalBlockHandler extends RpcHandler
     private int index = 0;
     private final Function<Integer, ManagedBuffer> blockDataForIndexFn;
     private final int size;
-    private boolean requestForMergedBlockChunks;
 
     ManagedBufferIterator(OpenBlocks msg) {
       String appId = msg.appId;
@@ -377,13 +377,14 @@ public class ExternalBlockHandler extends RpcHandler
         size = mapIdAndReduceIds.length;
         blockDataForIndexFn = index -> blockManager.getBlockData(appId, execId, shuffleId,
           mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]);
-      } else if (blockId0Parts.length == 4 && blockId0Parts[0].equals(SHUFFLE_CHUNK_ID)) {
-        requestForMergedBlockChunks = true;
+      } else if (blockId0Parts.length == 5 && blockId0Parts[0].equals(SHUFFLE_CHUNK_ID)) {
         final int shuffleId = Integer.parseInt(blockId0Parts[1]);
-        final int[] reduceIdAndChunkIds = shuffleMapIdAndReduceIds(blockIds, shuffleId);
+        final int shuffleMergeId = Integer.parseInt(blockId0Parts[2]);
+        final int[] reduceIdAndChunkIds = shuffleReduceIdAndChunkIds(blockIds, shuffleId,
+          shuffleMergeId);
         size = reduceIdAndChunkIds.length;
         blockDataForIndexFn = index -> mergeManager.getMergedBlockData(msg.appId, shuffleId,
-          reduceIdAndChunkIds[index], reduceIdAndChunkIds[index + 1]);
+          shuffleMergeId, reduceIdAndChunkIds[index], reduceIdAndChunkIds[index + 1]);
       } else if (blockId0Parts.length == 3 && blockId0Parts[0].equals("rdd")) {
         final int[] rddAndSplitIds = rddAndSplitIds(blockIds);
         size = rddAndSplitIds.length;
@@ -407,27 +408,64 @@ public class ExternalBlockHandler extends RpcHandler
       return rddAndSplitIds;
     }
 
+    /**
+     * @param blockIds Regular shuffle blockIds starts with SHUFFLE_BLOCK_ID to be parsed
+     * @param shuffleId shuffle blocks shuffleId
+     * @return mapId and reduceIds of the shuffle blocks in the same order as that of the blockIds
+     *
+     * Regular shuffle blocks format should be shuffle_$shuffleId_$mapId_$reduceId
+     */
     private int[] shuffleMapIdAndReduceIds(String[] blockIds, int shuffleId) {
-      // For regular shuffle blocks, primaryId is mapId and secondaryIds are reduceIds.
-      // For shuffle chunks, primaryIds is reduceId and secondaryIds are chunkIds.
-      final int[] primaryIdAndSecondaryIds = new int[2 * blockIds.length];
+      final int[] mapIdAndReduceIds = new int[2 * blockIds.length];
       for (int i = 0; i < blockIds.length; i++) {
         String[] blockIdParts = blockIds[i].split("_");
-        if (blockIdParts.length != 4
-          || (!requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_BLOCK_ID))
-          || (requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_CHUNK_ID))) {
+        if (blockIdParts.length != 4 || !blockIdParts[0].equals(SHUFFLE_BLOCK_ID)) {
           throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]);
         }
         if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
           throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
             ", got:" + blockIds[i]);
         }
-        // For regular blocks, blockIdParts[2] is mapId. For chunks, it is reduceId.
-        primaryIdAndSecondaryIds[2 * i] = Integer.parseInt(blockIdParts[2]);
-        // For regular blocks, blockIdParts[3] is reduceId. For chunks, it is chunkId.
-        primaryIdAndSecondaryIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]);
+        // mapId
+        mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]);
+        // reduceId
+        mapIdAndReduceIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]);
       }
-      return primaryIdAndSecondaryIds;
+      return mapIdAndReduceIds;
+    }
+
+    /**
+     * @param blockIds Shuffle merged chunks starts with SHUFFLE_CHUNK_ID to be parsed
+     * @param shuffleId shuffle blocks shuffleId
+     * @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process
+     *                       of shuffle by an indeterminate stage attempt.
+     * @return reduceId and chunkIds of the shuffle chunks in the same order as that of the
+     *         blockIds
+     *
+     * Shuffle merged chunks format should be
+     * shuffleChunk_$shuffleId_$shuffleMergeId_$reduceId_$chunkId
+     */
+    private int[] shuffleReduceIdAndChunkIds(
+        String[] blockIds,
+        int shuffleId,
+        int shuffleMergeId) {
+      final int[] reduceIdAndChunkIds = new int[2 * blockIds.length];
+      for(int i = 0; i < blockIds.length; i++) {
+        String[] blockIdParts = blockIds[i].split("_");
+        if (blockIdParts.length != 5 || !blockIdParts[0].equals(SHUFFLE_CHUNK_ID)) {
+          throw new IllegalArgumentException("Unexpected shuffle chunk id format: " + blockIds[i]);
+        }
+        if (Integer.parseInt(blockIdParts[1]) != shuffleId ||
+            Integer.parseInt(blockIdParts[2]) != shuffleMergeId) {
+          throw new IllegalArgumentException(String.format("Expected shuffleId = %s"
+            + " and shuffleMergeId = %s but got %s", shuffleId, shuffleMergeId, blockIds[i]));
+        }
+        // reduceId
+        reduceIdAndChunkIds[2 * i] = Integer.parseInt(blockIdParts[3]);
+        // chunkId
+        reduceIdAndChunkIds[2 * i + 1] = Integer.parseInt(blockIdParts[4]);
+      }
+      return reduceIdAndChunkIds;
     }
 
     @Override
@@ -511,12 +549,14 @@ public class ExternalBlockHandler extends RpcHandler
 
     private final String appId;
     private final int shuffleId;
+    private final int shuffleMergeId;
     private final int[] reduceIds;
     private final int[][] chunkIds;
 
     ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) {
       appId = msg.appId;
       shuffleId = msg.shuffleId;
+      shuffleMergeId = msg.shuffleMergeId;
       reduceIds = msg.reduceIds;
       chunkIds = msg.chunkIds;
       // reduceIds.length must equal to chunkIds.length, and the passed in FetchShuffleBlockChunks
@@ -533,7 +573,7 @@ public class ExternalBlockHandler extends RpcHandler
     @Override
     public ManagedBuffer next() {
       ManagedBuffer block = Preconditions.checkNotNull(mergeManager.getMergedBlockData(
-        appId, shuffleId, reduceIds[reduceIdx], chunkIds[reduceIdx][chunkIdx]));
+        appId, shuffleId, shuffleMergeId, reduceIds[reduceIdx], chunkIds[reduceIdx][chunkIdx]));
       if (chunkIdx < chunkIds[reduceIdx].length - 1) {
         chunkIdx += 1;
       } else {
@@ -580,12 +620,20 @@ public class ExternalBlockHandler extends RpcHandler
 
     @Override
     public ManagedBuffer getMergedBlockData(
-        String appId, int shuffleId, int reduceId, int chunkId) {
+        String appId,
+        int shuffleId,
+        int shuffleMergeId,
+        int reduceId,
+        int chunkId) {
       throw new UnsupportedOperationException("Cannot handle shuffle block merge");
     }
 
     @Override
-    public MergedBlockMeta getMergedBlockMeta(String appId, int shuffleId, int reduceId) {
+    public MergedBlockMeta getMergedBlockMeta(
+        String appId,
+        int shuffleId,
+        int shuffleMergeId,
+        int reduceId) {
       throw new UnsupportedOperationException("Cannot handle shuffle block merge");
     }
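
A minimal sketch of the merged-chunk id parsing `shuffleReduceIdAndChunkIds` performs above: the id format gains a shuffleMergeId component, going from four parts (`shuffleChunk_$shuffleId_$reduceId_$chunkId`) to five (`shuffleChunk_$shuffleId_$shuffleMergeId_$reduceId_$chunkId`). The id value used here is illustrative.

```java
public class ChunkIdParseSketch {
  public static void main(String[] args) {
    String blockId = "shuffleChunk_3_1_7_0";
    String[] parts = blockId.split("_");
    if (parts.length != 5 || !parts[0].equals("shuffleChunk")) {
      throw new IllegalArgumentException("Unexpected shuffle chunk id format: " + blockId);
    }
    int shuffleId = Integer.parseInt(parts[1]);
    int shuffleMergeId = Integer.parseInt(parts[2]); // new component in this commit
    int reduceId = Integer.parseInt(parts[3]);
    int chunkId = Integer.parseInt(parts[4]);
    System.out.printf("shuffleId=%d shuffleMergeId=%d reduceId=%d chunkId=%d%n",
      shuffleId, shuffleMergeId, reduceId, chunkId);
  }
}
```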
 
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java
index f88915b..eb2d118 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java
@@ -172,12 +172,14 @@ public class ExternalBlockStoreClient extends BlockStoreClient {
       String host,
       int port,
       int shuffleId,
+      int shuffleMergeId,
       MergeFinalizerListener listener) {
     checkInit();
     try {
       TransportClient client = clientFactory.createClient(host, port);
       ByteBuffer finalizeShuffleMerge =
-        new FinalizeShuffleMerge(appId, conf.appAttemptId(), shuffleId).toByteBuffer();
+        new FinalizeShuffleMerge(appId, conf.appAttemptId(), shuffleId,
+          shuffleMergeId).toByteBuffer();
       client.sendRpc(finalizeShuffleMerge, new RpcResponseCallback() {
         @Override
         public void onSuccess(ByteBuffer response) {
@@ -202,29 +204,31 @@ public class ExternalBlockStoreClient extends BlockStoreClient {
       String host,
       int port,
       int shuffleId,
+      int shuffleMergeId,
       int reduceId,
       MergedBlocksMetaListener listener) {
     checkInit();
-    logger.debug("Get merged blocks meta from {}:{} for shuffleId {} reduceId {}", host, port,
-      shuffleId, reduceId);
+    logger.debug("Get merged blocks meta from {}:{} for shuffleId {} shuffleMergeId {}"
+      + " reduceId {}", host, port, shuffleId, shuffleMergeId, reduceId);
     try {
       TransportClient client = clientFactory.createClient(host, port);
-      client.sendMergedBlockMetaReq(appId, shuffleId, reduceId,
+      client.sendMergedBlockMetaReq(appId, shuffleId, shuffleMergeId, reduceId,
         new MergedBlockMetaResponseCallback() {
           @Override
           public void onSuccess(int numChunks, ManagedBuffer buffer) {
-            logger.trace("Successfully got merged block meta for shuffleId {} reduceId {}",
-              shuffleId, reduceId);
-            listener.onSuccess(shuffleId, reduceId, new MergedBlockMeta(numChunks, buffer));
+            logger.trace("Successfully got merged block meta for shuffleId {} shuffleMergeId {}"
+              + " reduceId {}", shuffleId, shuffleMergeId, reduceId);
+            listener.onSuccess(shuffleId, shuffleMergeId, reduceId,
+              new MergedBlockMeta(numChunks, buffer));
           }
 
           @Override
           public void onFailure(Throwable e) {
-            listener.onFailure(shuffleId, reduceId, e);
+            listener.onFailure(shuffleId, shuffleMergeId, reduceId, e);
           }
         });
     } catch (Exception e) {
-      listener.onFailure(shuffleId, reduceId, e);
+      listener.onFailure(shuffleId, shuffleMergeId, reduceId, e);
     }
   }
 
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedBlocksMetaListener.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedBlocksMetaListener.java
index 0e277d3..cea76dd 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedBlocksMetaListener.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedBlocksMetaListener.java
@@ -30,17 +30,21 @@ public interface MergedBlocksMetaListener extends EventListener {
    * Called after successfully receiving the meta of a merged block.
    *
    * @param shuffleId shuffle id.
+   * @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process
+   *                       of shuffle by an indeterminate stage attempt.
    * @param reduceId reduce id.
    * @param meta contains meta information of a merged block.
    */
-  void onSuccess(int shuffleId, int reduceId, MergedBlockMeta meta);
+  void onSuccess(int shuffleId, int shuffleMergeId, int reduceId, MergedBlockMeta meta);
 
   /**
    * Called when there is an exception while fetching the meta of a merged block.
    *
    * @param shuffleId shuffle id.
+   * @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process
+   *                       of shuffle by an indeterminate stage attempt.
    * @param reduceId reduce id.
    * @param exception exception getting chunk counts.
    */
-  void onFailure(int shuffleId, int reduceId, Throwable exception);
+  void onFailure(int shuffleId, int shuffleMergeId, int reduceId, Throwable exception);
 }
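
A minimal sketch of implementing the widened listener, assuming the network-shuffle classes above are on the classpath (the logging bodies are illustrative): the shuffleMergeId now travels through both callbacks, so a caller can discard results that belong to a superseded merge.

```java
import org.apache.spark.network.shuffle.MergedBlockMeta;
import org.apache.spark.network.shuffle.MergedBlocksMetaListener;

public class MetaListenerSketch {
  // A listener that simply logs; a real caller would hand the chunk count to
  // the fetch iterator and drop callbacks from superseded shuffleMergeIds.
  static final MergedBlocksMetaListener LISTENER = new MergedBlocksMetaListener() {
    @Override
    public void onSuccess(int shuffleId, int shuffleMergeId, int reduceId, MergedBlockMeta meta) {
      System.out.printf("meta for shuffle %d mergeId %d reduce %d: %d chunks%n",
        shuffleId, shuffleMergeId, reduceId, meta.getNumChunks());
    }

    @Override
    public void onFailure(int shuffleId, int shuffleMergeId, int reduceId, Throwable t) {
      System.err.printf("meta fetch failed for shuffle %d mergeId %d reduce %d: %s%n",
        shuffleId, shuffleMergeId, reduceId, t.getMessage());
    }
  };
}
```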
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedShuffleFileManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedShuffleFileManager.java
index 4ce6a47..630386d 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedShuffleFileManager.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedShuffleFileManager.java
@@ -85,21 +85,34 @@ public interface MergedShuffleFileManager {
    *
    * @param appId application ID
    * @param shuffleId shuffle ID
+   * @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process
+   *                       of shuffle by an indeterminate stage attempt.
    * @param reduceId reducer ID
    * @param chunkId merged shuffle file chunk ID
    * @return The {@link ManagedBuffer} for the given merged shuffle chunk
    */
-  ManagedBuffer getMergedBlockData(String appId, int shuffleId, int reduceId, int chunkId);
+  ManagedBuffer getMergedBlockData(
+      String appId,
+      int shuffleId,
+      int shuffleMergeId,
+      int reduceId,
+      int chunkId);
 
   /**
    * Get the meta information of a merged block.
    *
    * @param appId application ID
    * @param shuffleId shuffle ID
+   * @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process
+   *                       of shuffle by an indeterminate stage attempt.
    * @param reduceId reducer ID
    * @return meta information of a merged block
    */
-  MergedBlockMeta getMergedBlockMeta(String appId, int shuffleId, int reduceId);
+  MergedBlockMeta getMergedBlockMeta(
+      String appId,
+      int shuffleId,
+      int shuffleMergeId,
+      int reduceId);
 
   /**
    * Get the local directories which stores the merged shuffle files.
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
index dd1a715..a788b50 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -22,7 +22,7 @@ import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.LinkedHashMap;
-import java.util.Set;
+import java.util.Map;
 
 import com.google.common.primitives.Ints;
 import com.google.common.primitives.Longs;
@@ -56,6 +56,8 @@ public class OneForOneBlockFetcher {
   private static final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class);
   private static final String SHUFFLE_BLOCK_PREFIX = "shuffle_";
   private static final String SHUFFLE_CHUNK_PREFIX = "shuffleChunk_";
+  private static final String SHUFFLE_BLOCK_SPLIT = "shuffle";
+  private static final String SHUFFLE_CHUNK_SPLIT = "shuffleChunk";
 
   private final TransportClient client;
   private final BlockTransferMessage message;
@@ -125,63 +127,87 @@ public class OneForOneBlockFetcher {
       String execId,
       String[] blockIds) {
     if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
-      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+      return createFetchShuffleChunksMsg(appId, execId, blockIds);
     } else {
-      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+      return createFetchShuffleBlocksMsg(appId, execId, blockIds);
     }
   }
 
-  /**
-   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
-   * analyzing the passed in blockIds.
-   */
-  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksMsg(
       String appId,
       String execId,
-      String[] blockIds,
-      boolean areMergedChunks) {
+      String[] blockIds) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
-
-    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
-    // is reduceId.
-    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
+    Map<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
-        throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
-          ", got:" + blockId);
-      }
-      Number primaryId;
-      if (!areMergedChunks) {
-        primaryId = Long.parseLong(blockIdParts[2]);
-      } else {
-        primaryId = Integer.parseInt(blockIdParts[2]);
+        throw new IllegalArgumentException("Expected shuffleId=" + shuffleId + ", got:" + blockId);
       }
-      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.computeIfAbsent(primaryId,
+
+      long mapId = Long.parseLong(blockIdParts[2]);
+      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.computeIfAbsent(mapId,
         id -> new BlocksInfo());
-      blocksInfoByPrimaryId.blockIds.add(blockId);
-      // If blockId is a regular shuffle block, then blockIdParts[3] = reduceId. If blockId is a
-      // shuffleChunk block, then blockIdParts[3] = chunkId
-      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
+      blocksInfoByMapId.blockIds.add(blockId);
+      blocksInfoByMapId.ids.add(Integer.parseInt(blockIdParts[3]));
+
       if (batchFetchEnabled) {
-        // It comes here only if the blockId is a regular shuffle block not a shuffleChunk block.
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
         // blockIdParts[4] is the end reduce id for the batch range
-        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByMapId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
+
+    int[][] reduceIdsArray = getSecondaryIds(mapIdToBlocksInfo);
+    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
+    return new FetchShuffleBlocks(
+      appId, execId, shuffleId, mapIds, reduceIdsArray, batchFetchEnabled);
+  }
+
+  private AbstractFetchShuffleBlocks createFetchShuffleChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    String[] firstBlock = splitBlockId(blockIds[0]);
+    int shuffleId = Integer.parseInt(firstBlock[1]);
+    int shuffleMergeId = Integer.parseInt(firstBlock[2]);
+
+    Map<Integer, BlocksInfo> reduceIdToBlocksInfo = new LinkedHashMap<>();
+    for (String blockId : blockIds) {
+      String[] blockIdParts = splitBlockId(blockId);
+      if (Integer.parseInt(blockIdParts[1]) != shuffleId ||
+          Integer.parseInt(blockIdParts[2]) != shuffleMergeId) {
+        throw new IllegalArgumentException(String.format("Expected shuffleId = %s and"
+          + " shuffleMergeId = %s but got %s", shuffleId, shuffleMergeId, blockId));
+      }
+
+      int reduceId = Integer.parseInt(blockIdParts[3]);
+      BlocksInfo blocksInfoByReduceId = reduceIdToBlocksInfo.computeIfAbsent(reduceId,
+        id -> new BlocksInfo());
+      blocksInfoByReduceId.blockIds.add(blockId);
+      blocksInfoByReduceId.ids.add(Integer.parseInt(blockIdParts[4]));
+    }
+
+    int[][] chunkIdsArray = getSecondaryIds(reduceIdToBlocksInfo);
+    int[] reduceIds = Ints.toArray(reduceIdToBlocksInfo.keySet());
+
+    return new FetchShuffleBlockChunks(appId, execId, shuffleId, shuffleMergeId, reduceIds,
+      chunkIdsArray);
+  }
+
+  private int[][] getSecondaryIds(Map<? extends Number, BlocksInfo> primaryIdsToBlockInfo) {
     // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
     // secondaryIds are chunkIds.
-    int[][] secondaryIdsArray = new int[primaryIdToBlocksInfo.size()][];
+    int[][] secondaryIds = new int[primaryIdsToBlockInfo.size()][];
     int blockIdIndex = 0;
     int secIndex = 0;
-    for (BlocksInfo blocksInfo: primaryIdToBlocksInfo.values()) {
-      secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfo.ids);
+    for (BlocksInfo blocksInfo: primaryIdsToBlockInfo.values()) {
+      secondaryIds[secIndex++] = Ints.toArray(blocksInfo.ids);
 
       // The `blockIds`'s order must be same with the read order specified in FetchShuffleBlocks/
       // FetchShuffleBlockChunks because the shuffle data's return order should match the
@@ -191,35 +217,27 @@ public class OneForOneBlockFetcher {
       }
     }
     assert(blockIdIndex == this.blockIds.length);
-    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
-    if (!areMergedChunks) {
-      long[] mapIds = Longs.toArray(primaryIds);
-      return new FetchShuffleBlocks(
-        appId, execId, shuffleId, mapIds, secondaryIdsArray, batchFetchEnabled);
-    } else {
-      int[] reduceIds = Ints.toArray(primaryIds);
-      return new FetchShuffleBlockChunks(appId, execId, shuffleId, reduceIds, secondaryIdsArray);
-    }
+    return secondaryIds;
   }
 
-  /** Split the shuffleBlockId and return shuffleId, mapId and reduceIds. */
+  /**
+   * Split the blockId and return accordingly
+   * shuffleChunk - return shuffleId, shuffleMergeId, reduceId and chunkIds
+   * shuffle block - return shuffleId, mapId, reduceId
+   * shuffle batch block - return shuffleId, mapId, begin reduceId and end reduceId
+   */
   private String[] splitBlockId(String blockId) {
     String[] blockIdParts = blockId.split("_");
-    // For batch block id, the format contains shuffleId, mapId, begin reduceId, end reduceId.
-    // For single block id, the format contains shuffleId, mapId, educeId.
-    // For single block chunk id, the format contains shuffleId, reduceId, chunkId.
     if (blockIdParts.length < 4 || blockIdParts.length > 5) {
-      throw new IllegalArgumentException(
-        "Unexpected shuffle block id format: " + blockId);
+      throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockId);
     }
-    if (blockIdParts.length == 5 && !blockIdParts[0].equals("shuffle")) {
-      throw new IllegalArgumentException(
-        "Unexpected shuffle block id format: " + blockId);
+    if (blockIdParts.length == 4 && !blockIdParts[0].equals(SHUFFLE_BLOCK_SPLIT)) {
+      throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockId);
     }
-    if (blockIdParts.length == 4 &&
-      !(blockIdParts[0].equals("shuffle") || blockIdParts[0].equals("shuffleChunk"))) {
-      throw new IllegalArgumentException(
-        "Unexpected shuffle block id format: " + blockId);
+    if (blockIdParts.length == 5 &&
+      !(blockIdParts[0].equals(SHUFFLE_BLOCK_SPLIT) ||
+        blockIdParts[0].equals(SHUFFLE_CHUNK_SPLIT))) {
+      throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockId);
     }
     return blockIdParts;
   }
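
A minimal sketch of the grouping `createFetchShuffleChunksMsg` performs above: chunk block ids for one (shuffleId, shuffleMergeId) pair are grouped by reduceId in arrival order, yielding the parallel `reduceIds[]` and `chunkIds[][]` arrays that `FetchShuffleBlockChunks` carries. The block ids are illustrative.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class ChunkGroupingSketch {
  public static void main(String[] args) {
    String[] blockIds = {
      "shuffleChunk_3_1_7_0", "shuffleChunk_3_1_7_1", "shuffleChunk_3_1_9_0"
    };
    // LinkedHashMap keeps insertion order, matching the read order the
    // fetcher must preserve when the shuffle data comes back.
    Map<Integer, List<Integer>> reduceIdToChunkIds = new LinkedHashMap<>();
    for (String blockId : blockIds) {
      String[] parts = blockId.split("_");
      int reduceId = Integer.parseInt(parts[3]);
      int chunkId = Integer.parseInt(parts[4]);
      reduceIdToChunkIds.computeIfAbsent(reduceId, k -> new ArrayList<>()).add(chunkId);
    }
    // Equivalent of reduceIds = [7, 9] and chunkIds = [[0, 1], [0]].
    System.out.println(reduceIdToChunkIds); // {7=[0, 1], 9=[0]}
  }
}
```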
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockPusher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockPusher.java
index 0e1c59f..f9d313c 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockPusher.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockPusher.java
@@ -135,13 +135,14 @@ public class OneForOneBlockPusher {
       assert buffers.containsKey(blockIds[i]) : "Could not find the block buffer for block "
         + blockIds[i];
       String[] blockIdParts = blockIds[i].split("_");
-      if (blockIdParts.length != 4 || !blockIdParts[0].equals(SHUFFLE_PUSH_BLOCK_PREFIX)) {
+      if (blockIdParts.length != 5 || !blockIdParts[0].equals(SHUFFLE_PUSH_BLOCK_PREFIX)) {
         throw new IllegalArgumentException(
           "Unexpected shuffle push block id format: " + blockIds[i]);
       }
       ByteBuffer header =
         new PushBlockStream(appId, appAttemptId, Integer.parseInt(blockIdParts[1]),
-          Integer.parseInt(blockIdParts[2]), Integer.parseInt(blockIdParts[3]) , i).toByteBuffer();
+          Integer.parseInt(blockIdParts[2]), Integer.parseInt(blockIdParts[3]),
+            Integer.parseInt(blockIdParts[4]), i).toByteBuffer();
       client.uploadStream(new NioManagedBuffer(header), buffers.get(blockIds[i]),
         new BlockPushCallback(i, blockIds[i]));
     }
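
A minimal sketch of the push block id shape this change enforces, with illustrative values: the id gains a shuffleMergeId component, so `OneForOneBlockPusher` now expects five underscore-separated parts instead of four.

```java
public class PushBlockIdSketch {
  public static void main(String[] args) {
    int shuffleId = 3, shuffleMergeId = 1, mapIndex = 12, reduceId = 7;
    // shufflePush_$shuffleId_$shuffleMergeId_$mapIndex_$reduceId
    String pushBlockId = String.format("shufflePush_%d_%d_%d_%d",
      shuffleId, shuffleMergeId, mapIndex, reduceId);
    System.out.println(pushBlockId); // shufflePush_3_1_12_7
  }
}
```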
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
index f88cfee..cc7d4db 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
@@ -28,6 +28,7 @@ import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ConcurrentMap;
@@ -78,6 +79,15 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
   public static final String MERGE_DIR_KEY = "mergeDir";
   public static final String ATTEMPT_ID_KEY = "attemptId";
   private static final int UNDEFINED_ATTEMPT_ID = -1;
+  // Shuffles of determinate stages will have shuffleMergeId set to 0
+  private static final int DETERMINATE_SHUFFLE_MERGE_ID = 0;
+
+  // ConcurrentHashMap doesn't allow null for keys or values which is why this is required.
+  // Marker to identify finalized indeterminate shuffle partitions in the case of indeterminate
+  // stage retries.
+  @VisibleForTesting
+  public static final Map<Integer, AppShufflePartitionInfo> INDETERMINATE_SHUFFLE_FINALIZED =
+    Collections.emptyMap();
 
   /**
    * A concurrent hashmap where the key is the applicationId, and the value includes
@@ -128,50 +138,79 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
   }
 
   /**
-   * Given the appShuffleInfo, shuffleId and reduceId that uniquely identifies a given shuffle
-   * partition of an application, retrieves the associated metadata. If not present and the
-   * corresponding merged shuffle does not exist, initializes the metadata.
+   * Given the appShuffleInfo, shuffleId, shuffleMergeId and reduceId that uniquely identifies
+   * a given shuffle partition of an application, retrieves the associated metadata. If not
+   * present and the corresponding merged shuffle does not exist, initializes the metadata.
    */
   private AppShufflePartitionInfo getOrCreateAppShufflePartitionInfo(
       AppShuffleInfo appShuffleInfo,
       int shuffleId,
-      int reduceId) {
-    File dataFile = appShuffleInfo.getMergedShuffleDataFile(shuffleId, reduceId);
-    ConcurrentMap<Integer, Map<Integer, AppShufflePartitionInfo>> partitions =
-      appShuffleInfo.partitions;
-    Map<Integer, AppShufflePartitionInfo> shufflePartitions =
-      partitions.compute(shuffleId, (id, map) -> {
-        if (map == null) {
+      int shuffleMergeId,
+      int reduceId) throws StaleBlockPushException {
+    ConcurrentMap<Integer, AppShuffleMergePartitionsInfo> shuffles = appShuffleInfo.shuffles;
+    AppShuffleMergePartitionsInfo shufflePartitionsWithMergeId =
+      shuffles.compute(shuffleId, (id, appShuffleMergePartitionsInfo) -> {
+        if (appShuffleMergePartitionsInfo == null) {
+          File dataFile =
+            appShuffleInfo.getMergedShuffleDataFile(shuffleId, shuffleMergeId, reduceId);
           // If this partition is already finalized then the partitions map will not contain the
-          // shuffleId but the data file would exist. In that case the block is considered late.
+          // shuffleId for determinate stages but the data file would exist.
+          // In that case the block is considered late. In the case of indeterminate stages, most
+          // recent shuffleMergeId finalized would be pointing to INDETERMINATE_SHUFFLE_FINALIZED
           if (dataFile.exists()) {
             return null;
+          } else {
+            logger.info("Creating a new attempt for shuffle blocks push request for shuffle {}"
+              + " with shuffleMergeId {} for application {}_{}", shuffleId, shuffleMergeId,
+              appShuffleInfo.appId, appShuffleInfo.attemptId);
+            return new AppShuffleMergePartitionsInfo(shuffleMergeId, false);
           }
-          return new ConcurrentHashMap<>();
         } else {
-          return map;
+          // Reject the request as we have already seen a higher shuffleMergeId than the
+          // current incoming one
+          int latestShuffleMergeId = appShuffleMergePartitionsInfo.shuffleMergeId;
+          if (latestShuffleMergeId > shuffleMergeId) {
+            throw new StaleBlockPushException(String.format("Rejecting shuffle blocks push request"
+              + " for shuffle %s with shuffleMergeId %s for application %s_%s as a higher"
+              + " shuffleMergeId %s request is already seen", shuffleId, shuffleMergeId,
+              appShuffleInfo.appId, appShuffleInfo.attemptId, latestShuffleMergeId));
+          } else if (latestShuffleMergeId == shuffleMergeId) {
+            return appShuffleMergePartitionsInfo;
+          } else {
+            // Higher shuffleMergeId seen for the shuffle ID meaning new stage attempt is being
+            // run for the shuffle ID. Close and clean up old shuffleMergeId files,
+            // happens in the indeterminate stage retries
+            logger.info("Creating a new attempt for shuffle blocks push request for shuffle {}"
+              + " with shuffleMergeId {} for application {}_{} since it is higher than the"
+              + " latest shuffleMergeId {} already seen", shuffleId, shuffleMergeId,
+              appShuffleInfo.appId, appShuffleInfo.attemptId, latestShuffleMergeId);
+            mergedShuffleCleaner.execute(() ->
+              closeAndDeletePartitionFiles(appShuffleMergePartitionsInfo.shuffleMergePartitions));
+            return new AppShuffleMergePartitionsInfo(shuffleMergeId, false);
+          }
         }
       });
-    if (shufflePartitions == null) {
+
+    // It only gets here when the shuffle is already finalized.
+    if (null == shufflePartitionsWithMergeId ||
+        INDETERMINATE_SHUFFLE_FINALIZED == shufflePartitionsWithMergeId.shuffleMergePartitions) {
       return null;
     }
 
-    return shufflePartitions.computeIfAbsent(reduceId, key -> {
-      // It only gets here when the key is not present in the map. This could either
-      // be the first time the merge manager receives a pushed block for a given application
-      // shuffle partition, or after the merged shuffle file is finalized. We handle these
-      // two cases accordingly by checking if the file already exists.
+    Map<Integer, AppShufflePartitionInfo> shuffleMergePartitions =
+      shufflePartitionsWithMergeId.shuffleMergePartitions;
+    return shuffleMergePartitions.computeIfAbsent(reduceId, key -> {
+      // It only gets here when the key is not present in the map. The first time the merge
+      // manager receives a pushed block for a given application shuffle partition.
+      File dataFile =
+        appShuffleInfo.getMergedShuffleDataFile(shuffleId, shuffleMergeId, reduceId);
       File indexFile =
-        appShuffleInfo.getMergedShuffleIndexFile(shuffleId, reduceId);
+        appShuffleInfo.getMergedShuffleIndexFile(shuffleId, shuffleMergeId, reduceId);
       File metaFile =
-        appShuffleInfo.getMergedShuffleMetaFile(shuffleId, reduceId);
+        appShuffleInfo.getMergedShuffleMetaFile(shuffleId, shuffleMergeId, reduceId);
       try {
-        if (dataFile.exists()) {
-          return null;
-        } else {
-          return newAppShufflePartitionInfo(
-            appShuffleInfo.appId, shuffleId, reduceId, dataFile, indexFile, metaFile);
-        }
+        return newAppShufflePartitionInfo(appShuffleInfo.appId, shuffleId, shuffleMergeId,
+          reduceId, dataFile, indexFile, metaFile);
       } catch (IOException e) {
         logger.error(
           "Cannot create merged shuffle partition with data file {}, index file {}, and "
@@ -179,7 +218,8 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
             indexFile.getAbsolutePath(), metaFile.getAbsolutePath());
         throw new RuntimeException(
           String.format("Cannot initialize merged shuffle partition for appId %s shuffleId %s "
-            + "reduceId %s", appShuffleInfo.appId, shuffleId, reduceId), e);
+            + "shuffleMergeId %s reduceId %s", appShuffleInfo.appId, shuffleId, shuffleMergeId,
+              reduceId), e);
       }
     });
   }
@@ -188,19 +228,31 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
   AppShufflePartitionInfo newAppShufflePartitionInfo(
       String appId,
       int shuffleId,
+      int shuffleMergeId,
       int reduceId,
       File dataFile,
       File indexFile,
       File metaFile) throws IOException {
-    return new AppShufflePartitionInfo(appId, shuffleId, reduceId, dataFile,
+    return new AppShufflePartitionInfo(appId, shuffleId, shuffleMergeId, reduceId, dataFile,
       new MergeShuffleFile(indexFile), new MergeShuffleFile(metaFile));
   }
 
   @Override
-  public MergedBlockMeta getMergedBlockMeta(String appId, int shuffleId, int reduceId) {
+  public MergedBlockMeta getMergedBlockMeta(
+      String appId,
+      int shuffleId,
+      int shuffleMergeId,
+      int reduceId) {
     AppShuffleInfo appShuffleInfo = validateAndGetAppShuffleInfo(appId);
+    AppShuffleMergePartitionsInfo partitionsInfo = appShuffleInfo.shuffles.get(shuffleId);
+    if (null != partitionsInfo && partitionsInfo.shuffleMergeId > shuffleMergeId) {
+      throw new RuntimeException(String.format(
+        "MergedBlockMeta fetch for shuffle %s with shuffleMergeId %s reduceId %s is %s",
+        shuffleId, shuffleMergeId, reduceId,
+        ErrorHandler.BlockFetchErrorHandler.STALE_SHUFFLE_BLOCK_FETCH));
+    }
     File indexFile =
-      appShuffleInfo.getMergedShuffleIndexFile(shuffleId, reduceId);
+      appShuffleInfo.getMergedShuffleIndexFile(shuffleId, shuffleMergeId, reduceId);
     if (!indexFile.exists()) {
       throw new RuntimeException(String.format(
         "Merged shuffle index file %s not found", indexFile.getPath()));
@@ -208,7 +260,7 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
     int size = (int) indexFile.length();
     // First entry is the zero offset
     int numChunks = (size / Long.BYTES) - 1;
-    File metaFile = appShuffleInfo.getMergedShuffleMetaFile(shuffleId, reduceId);
+    File metaFile = appShuffleInfo.getMergedShuffleMetaFile(shuffleId, shuffleMergeId, reduceId);
     if (!metaFile.exists()) {
       throw new RuntimeException(String.format("Merged shuffle meta file %s not found",
         metaFile.getPath()));
@@ -216,21 +268,30 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
     FileSegmentManagedBuffer chunkBitMaps =
       new FileSegmentManagedBuffer(conf, metaFile, 0L, metaFile.length());
     logger.trace(
-      "{} shuffleId {} reduceId {} num chunks {}", appId, shuffleId, reduceId, numChunks);
+      "{} shuffleId {} shuffleMergeId {} reduceId {} num chunks {}",
+      appId, shuffleId, shuffleMergeId, reduceId, numChunks);
     return new MergedBlockMeta(numChunks, chunkBitMaps);
   }
 
   @SuppressWarnings("UnstableApiUsage")
   @Override
-  public ManagedBuffer getMergedBlockData(String appId, int shuffleId, int reduceId, int chunkId) {
+  public ManagedBuffer getMergedBlockData(
+      String appId, int shuffleId, int shuffleMergeId, int reduceId, int chunkId) {
     AppShuffleInfo appShuffleInfo = validateAndGetAppShuffleInfo(appId);
-    File dataFile = appShuffleInfo.getMergedShuffleDataFile(shuffleId, reduceId);
+    AppShuffleMergePartitionsInfo partitionsInfo = appShuffleInfo.shuffles.get(shuffleId);
+    if (null != partitionsInfo && partitionsInfo.shuffleMergeId > shuffleMergeId) {
+      throw new RuntimeException(String.format(
+        "MergedBlockData fetch for shuffle %s with shuffleMergeId %s reduceId %s is %s",
+        shuffleId, shuffleMergeId, reduceId,
+        ErrorHandler.BlockFetchErrorHandler.STALE_SHUFFLE_BLOCK_FETCH));
+    }
+    File dataFile = appShuffleInfo.getMergedShuffleDataFile(shuffleId, shuffleMergeId, reduceId);
     if (!dataFile.exists()) {
       throw new RuntimeException(String.format("Merged shuffle data file %s not found",
         dataFile.getPath()));
     }
     File indexFile =
-      appShuffleInfo.getMergedShuffleIndexFile(shuffleId, reduceId);
+      appShuffleInfo.getMergedShuffleIndexFile(shuffleId, shuffleMergeId, reduceId);
     try {
       // If we get here, the merged shuffle file should have been properly finalized. Thus we can
       // use the file length to determine the size of the merged shuffle block.
@@ -270,19 +331,33 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
   void closeAndDeletePartitionFilesIfNeeded(
       AppShuffleInfo appShuffleInfo,
       boolean cleanupLocalDirs) {
-    for (Map<Integer, AppShufflePartitionInfo> partitionMap : appShuffleInfo.partitions.values()) {
-      for (AppShufflePartitionInfo partitionInfo : partitionMap.values()) {
+    appShuffleInfo.shuffles.forEach((shuffleId, shuffleInfo) -> shuffleInfo.shuffleMergePartitions
+      .forEach((shuffleMergeId, partitionInfo) -> {
         synchronized (partitionInfo) {
-          partitionInfo.closeAllFiles();
+          partitionInfo.closeAllFilesAndDeleteIfNeeded(false);
         }
-      }
-    }
+      }));
     if (cleanupLocalDirs) {
       deleteExecutorDirs(appShuffleInfo);
     }
   }
 
   /**
+   * Clean up all the AppShufflePartitionInfo for a specific shuffleMergeId. This is done
+   * when a request with a higher shuffleMergeId arrives for the same shuffleId, so the
+   * partitions of the older shuffleMergeId must be cleaned up. The cleanup will be
+   * executed in a separate thread.
+   */
+  @VisibleForTesting
+  void closeAndDeletePartitionFiles(Map<Integer, AppShufflePartitionInfo> partitions) {
+    partitions.forEach((partitionId, partitionInfo) -> {
+      synchronized (partitionInfo) {
+        partitionInfo.closeAllFilesAndDeleteIfNeeded(true);
+      }
+    });
+  }
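
The Javadoc above notes that this cleanup runs in a separate thread. A minimal sketch of
that pattern, assuming a hypothetical single-threaded `cleanerExecutor` and an
`AutoCloseable` stand-in for the per-partition state (the actual wiring inside
RemoteBlockPushResolver is not shown in this hunk):

    import java.util.Map;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;

    class StalePartitionCleanupSketch {
      // hypothetical dedicated cleaner thread, not the resolver's actual field
      private final ExecutorService cleanerExecutor =
        Executors.newSingleThreadExecutor();

      // stalePartitions: reduceId -> open partition state of an older shuffleMergeId
      void scheduleCleanup(Map<Integer, AutoCloseable> stalePartitions) {
        cleanerExecutor.submit(() -> {
          for (AutoCloseable partition : stalePartitions.values()) {
            try {
              partition.close(); // close channels and delete the stale files
            } catch (Exception e) {
              // best effort: a failure only leaves an orphaned file behind
            }
          }
        });
      }
    }

Offloading the close-and-delete work keeps the block-push handling threads from blocking
on file I/O for partitions that are already stale.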
+
+  /**
    * Serially delete local dirs.
    */
   @VisibleForTesting
@@ -304,9 +379,9 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
   @Override
   public StreamCallbackWithID receiveBlockDataAsStream(PushBlockStream msg) {
     AppShuffleInfo appShuffleInfo = validateAndGetAppShuffleInfo(msg.appId);
-    final String streamId = String.format("%s_%d_%d_%d",
-      OneForOneBlockPusher.SHUFFLE_PUSH_BLOCK_PREFIX, msg.shuffleId, msg.mapIndex,
-      msg.reduceId);
+    final String streamId = String.format("%s_%d_%d_%d_%d",
+      OneForOneBlockPusher.SHUFFLE_PUSH_BLOCK_PREFIX, msg.shuffleId, msg.shuffleMergeId,
+      msg.mapIndex, msg.reduceId);
     if (appShuffleInfo.attemptId != msg.appAttemptId) {
       // If this Block belongs to a former application attempt, it is considered late,
       // as only the blocks from the current application attempt will be merged
@@ -317,12 +392,19 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
           msg.appAttemptId, appShuffleInfo.attemptId, msg.appId));
     }
     // Retrieve merged shuffle file metadata
-    AppShufflePartitionInfo partitionInfoBeforeCheck =
-      getOrCreateAppShufflePartitionInfo(appShuffleInfo, msg.shuffleId, msg.reduceId);
-    // Here partitionInfo will be null in 2 cases:
+    AppShufflePartitionInfo partitionInfoBeforeCheck;
+    try {
+      partitionInfoBeforeCheck = getOrCreateAppShufflePartitionInfo(appShuffleInfo, msg.shuffleId,
+        msg.shuffleMergeId, msg.reduceId);
+    } catch (StaleBlockPushException sbp) {
+      // Set partitionInfoBeforeCheck to null so that the stale block push is handled
+      // below in the same way as a too-late block push.
+      partitionInfoBeforeCheck = null;
+    }
+    // Here partitionInfo will be null in 3 cases:
     // 1) The request is received for a block that has already been merged, this is possible due
     // to the retry logic.
     // 2) The request is received after the merged shuffle is finalized, thus is too late.
+    // 3) The request is received for an older shuffleMergeId, therefore the block push is
+    // rejected.
     //
     // For case 1, we will drain the data in the channel and just respond success
     // to the client. This is required because the response of the previously merged
@@ -345,6 +427,13 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
     // to notify the client of the failure, so that it can properly halt pushing the remaining
     // blocks upon receiving such failures to preserve resources on the server/client side.
     //
+    // For case 3, we will also drain the data in the channel, but throw an exception in
+    // {@link org.apache.spark.network.client.StreamCallback#onComplete(String)}. As with
+    // case 2, the client is notified of the failure while the channel remains active, so
+    // that it can properly halt pushing the remaining blocks and preserve resources on the
+    // server/client side.
+    //
     // Speculative execution would also raise a possible scenario with duplicate blocks. Although
     // speculative execution would kill the slower task attempt, leading to only 1 task attempt
     // succeeding in the end, there is no guarantee that only one copy of the block will be
@@ -353,18 +442,19 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
     // getting killed. When this happens, we need to distinguish the duplicate blocks as they
     // arrive. More details on this is explained in later comments.
 
-    // Track if the block is received after shuffle merge finalize
-    final boolean isTooLate = partitionInfoBeforeCheck == null;
-    // Check if the given block is already merged by checking the bitmap against the given map index
-    final AppShufflePartitionInfo partitionInfo = partitionInfoBeforeCheck != null
-      && partitionInfoBeforeCheck.mapTracker.contains(msg.mapIndex) ? null
-        : partitionInfoBeforeCheck;
+    // Track whether the block is received after the shuffle merge was finalized or
+    // belongs to an older shuffleMergeId attempt.
+    final boolean isStaleBlockOrTooLate = partitionInfoBeforeCheck == null;
+    // Check if the given block is already merged by checking the bitmap against the given map
+    // index
+    final AppShufflePartitionInfo partitionInfo = isStaleBlockOrTooLate ? null :
+      partitionInfoBeforeCheck.mapTracker.contains(msg.mapIndex) ? null : partitionInfoBeforeCheck;
     if (partitionInfo != null) {
       return new PushBlockStreamCallback(
         this, appShuffleInfo, streamId, partitionInfo, msg.mapIndex);
     } else {
-      // For a duplicate block or a block which is late, respond back with a callback that handles
-      // them differently.
+      // For a duplicate block, a block which is too late, or a stale block from an older
+      // shuffleMergeId, respond with a callback that handles them differently.
       return new StreamCallbackWithID() {
         @Override
         public String getID() {
@@ -379,11 +469,11 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
 
         @Override
         public void onComplete(String streamId) {
-          if (isTooLate) {
+          if (isStaleBlockOrTooLate) {
             // Throw an exception here so the block data is drained from channel and server
             // responds RpcFailure to the client.
             throw new RuntimeException(String.format("Block %s %s", streamId,
-              ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX));
+              ErrorHandler.BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX));
           }
           // For duplicate block that is received before the shuffle merge finalizes, the
           // server should respond success to the client.
@@ -397,9 +487,9 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
   }
 
   @Override
-  public MergeStatuses finalizeShuffleMerge(FinalizeShuffleMerge msg) throws IOException {
-    logger.info("Finalizing shuffle {} from Application {}_{}.",
-      msg.shuffleId, msg.appId, msg.appAttemptId);
+  public MergeStatuses finalizeShuffleMerge(FinalizeShuffleMerge msg) {
+    logger.info("Finalizing shuffle {} with shuffleMergeId {} from Application {}_{}.",
+      msg.shuffleId, msg.shuffleMergeId, msg.appId, msg.appAttemptId);
     AppShuffleInfo appShuffleInfo = validateAndGetAppShuffleInfo(msg.appId);
     if (appShuffleInfo.attemptId != msg.appAttemptId) {
       // If this Block belongs to a former application attempt, it is considered late,
@@ -410,17 +500,42 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
           + "with the current attempt id %s stored in shuffle service for application %s",
           msg.appAttemptId, appShuffleInfo.attemptId, msg.appId));
     }
-    Map<Integer, AppShufflePartitionInfo> shufflePartitions =
-      appShuffleInfo.partitions.remove(msg.shuffleId);
+    AtomicReference<Map<Integer, AppShufflePartitionInfo>> shuffleMergePartitionsRef =
+      new AtomicReference<>(null);
+    // Metadata of a determinate stage shuffle can be safely removed as part of finalizing
+    // the shuffle merge. Currently, once the shuffle is finalized for a determinate stage,
+    // retry stages of the same shuffle will have shuffle push disabled.
+    if (msg.shuffleMergeId == DETERMINATE_SHUFFLE_MERGE_ID) {
+      AppShuffleMergePartitionsInfo appShuffleMergePartitionsInfo =
+        appShuffleInfo.shuffles.remove(msg.shuffleId);
+      if (appShuffleMergePartitionsInfo != null) {
+        shuffleMergePartitionsRef.set(appShuffleMergePartitionsInfo.shuffleMergePartitions);
+      }
+    } else {
+      appShuffleInfo.shuffles.compute(msg.shuffleId, (id, value) -> {
+        if (null == value || msg.shuffleMergeId != value.shuffleMergeId ||
+          INDETERMINATE_SHUFFLE_FINALIZED == value.shuffleMergePartitions) {
+          throw new RuntimeException(String.format(
+            "Shuffle merge finalize request for shuffle %s with" + " shuffleMergeId %s is %s",
+            msg.shuffleId, msg.shuffleMergeId,
+            ErrorHandler.BlockPushErrorHandler.STALE_SHUFFLE_FINALIZE_SUFFIX));
+        } else {
+          shuffleMergePartitionsRef.set(value.shuffleMergePartitions);
+          return new AppShuffleMergePartitionsInfo(msg.shuffleMergeId, true);
+        }
+      });
+    }
+    Map<Integer, AppShufflePartitionInfo> shuffleMergePartitions = shuffleMergePartitionsRef.get();
     MergeStatuses mergeStatuses;
-    if (shufflePartitions == null || shufflePartitions.isEmpty()) {
+    if (null == shuffleMergePartitions || shuffleMergePartitions.isEmpty()) {
       mergeStatuses =
-        new MergeStatuses(msg.shuffleId, new RoaringBitmap[0], new int[0], new long[0]);
+        new MergeStatuses(msg.shuffleId, msg.shuffleMergeId,
+          new RoaringBitmap[0], new int[0], new long[0]);
     } else {
-      List<RoaringBitmap> bitmaps = new ArrayList<>(shufflePartitions.size());
-      List<Integer> reduceIds = new ArrayList<>(shufflePartitions.size());
-      List<Long> sizes = new ArrayList<>(shufflePartitions.size());
-      for (AppShufflePartitionInfo partition: shufflePartitions.values()) {
+      List<RoaringBitmap> bitmaps = new ArrayList<>(shuffleMergePartitions.size());
+      List<Integer> reduceIds = new ArrayList<>(shuffleMergePartitions.size());
+      List<Long> sizes = new ArrayList<>(shuffleMergePartitions.size());
+      for (AppShufflePartitionInfo partition: shuffleMergePartitions.values()) {
         synchronized (partition) {
           try {
            // This can throw IOException, which will mark this shuffle partition as not merged.
@@ -432,16 +547,16 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
             logger.warn("Exception while finalizing shuffle partition {}_{} {} {}", msg.appId,
               msg.appAttemptId, msg.shuffleId, partition.reduceId, ioe);
           } finally {
-            partition.closeAllFiles();
+            partition.closeAllFilesAndDeleteIfNeeded(false);
           }
         }
       }
-      mergeStatuses = new MergeStatuses(msg.shuffleId,
+      mergeStatuses = new MergeStatuses(msg.shuffleId, msg.shuffleMergeId,
         bitmaps.toArray(new RoaringBitmap[bitmaps.size()]), Ints.toArray(reduceIds),
         Longs.toArray(sizes));
     }
-    logger.info("Finalized shuffle {} from Application {}_{}.",
-      msg.shuffleId, msg.appId, msg.appAttemptId);
+    logger.info("Finalized shuffle {} with shuffleMergeId {} from Application {}_{}.",
+      msg.shuffleId, msg.shuffleMergeId, msg.appId, msg.appAttemptId);
     return mergeStatuses;
   }
 
@@ -563,9 +678,10 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
     private void writeBuf(ByteBuffer buf) throws IOException {
       while (buf.hasRemaining()) {
         long updatedPos = partitionInfo.getDataFilePos() + length;
-        logger.debug("{} shuffleId {} reduceId {} current pos {} updated pos {}",
-          partitionInfo.appId, partitionInfo.shuffleId,
-          partitionInfo.reduceId, partitionInfo.getDataFilePos(), updatedPos);
+        logger.debug("{} shuffleId {} shuffleMergeId {} reduceId {} current pos"
+          + " {} updated pos {}", partitionInfo.appId, partitionInfo.shuffleId,
+          partitionInfo.shuffleMergeId, partitionInfo.reduceId,
+          partitionInfo.getDataFilePos(), updatedPos);
         length += partitionInfo.dataChannel.write(buf, updatedPos);
       }
     }
@@ -631,6 +747,23 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
       abortIfNecessary();
     }
 
+    /**
+     * If appShuffleMergePartitionsInfo is null, or shuffleMergePartitions is set to
+     * INDETERMINATE_SHUFFLE_FINALIZED, or the reduceId is no longer in the map, then
+     * the shuffle is already finalized and the block push is too late. If the
+     * appShuffleMergePartitionsInfo's shuffleMergeId is greater than the request's
+     * shuffleMergeId, then it is a stale block push.
+     */
+    private boolean isStaleOrTooLate(
+        AppShuffleMergePartitionsInfo appShuffleMergePartitionsInfo,
+        int shuffleMergeId,
+        int reduceId) {
+      return null == appShuffleMergePartitionsInfo ||
+        INDETERMINATE_SHUFFLE_FINALIZED == appShuffleMergePartitionsInfo.shuffleMergePartitions ||
+          appShuffleMergePartitionsInfo.shuffleMergeId > shuffleMergeId ||
+          !appShuffleMergePartitionsInfo.shuffleMergePartitions.containsKey(reduceId);
+    }
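
The staleness ordering that isStaleOrTooLate relies on can be illustrated standalone. A
minimal sketch, with the reduceId-containment check elided and illustrative names rather
than the resolver's actual fields:

    class StalenessRuleSketch {
      static boolean isStaleOrTooLate(Integer serverMergeId, boolean finalized,
          int requestMergeId) {
        return serverMergeId == null          // shuffle unknown or already removed
          || finalized                        // merge finalized: push is too late
          || serverMergeId > requestMergeId;  // newer attempt exists: push is stale
      }

      public static void main(String[] args) {
        System.out.println(isStaleOrTooLate(2, false, 1)); // true: stale push
        System.out.println(isStaleOrTooLate(2, true, 2));  // true: too late
        System.out.println(isStaleOrTooLate(2, false, 2)); // false: current attempt
      }
    }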
+
     @Override
     public void onData(String streamId, ByteBuffer buf) throws IOException {
       // When handling the block data using StreamInterceptor, it can help to reduce the amount
@@ -648,14 +781,8 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
       // to disk as well. This way, we avoid having to buffer the entirety of every blocks in
       // memory, while still providing the necessary guarantee.
       synchronized (partitionInfo) {
-        Map<Integer, AppShufflePartitionInfo> shufflePartitions =
-          appShuffleInfo.partitions.get(partitionInfo.shuffleId);
-        // If the partitionInfo corresponding to (appId, shuffleId, reduceId) is no longer present
-        // then it means that the shuffle merge has already been finalized. We should thus ignore
-        // the data and just drain the remaining bytes of this message. This check should be
-        // placed inside the synchronized block to make sure that checking the key is still
-        // present and processing the data is atomic.
-        if (shufflePartitions == null || !shufflePartitions.containsKey(partitionInfo.reduceId)) {
+        if (isStaleOrTooLate(appShuffleInfo.shuffles.get(partitionInfo.shuffleId),
+            partitionInfo.shuffleMergeId, partitionInfo.reduceId)) {
           deferredBufs = null;
           return;
         }
@@ -668,8 +795,8 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
             return;
           }
           abortIfNecessary();
-          logger.trace("{} shuffleId {} reduceId {} onData writable",
-            partitionInfo.appId, partitionInfo.shuffleId,
+          logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} onData writable",
+            partitionInfo.appId, partitionInfo.shuffleId, partitionInfo.shuffleMergeId,
             partitionInfo.reduceId);
           if (partitionInfo.getCurrentMapIndex() < 0) {
             partitionInfo.setCurrentMapIndex(mapIndex);
@@ -690,8 +817,8 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
             throw ioe;
           }
         } else {
-          logger.trace("{} shuffleId {} reduceId {} onData deferred",
-            partitionInfo.appId, partitionInfo.shuffleId,
+          logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} onData deferred",
+            partitionInfo.appId, partitionInfo.shuffleId, partitionInfo.shuffleMergeId,
             partitionInfo.reduceId);
           // If we cannot write to disk, we buffer the current block chunk in memory so it could
           // potentially be written to disk later. We take our best effort without guarantee
@@ -725,19 +852,21 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
     @Override
     public void onComplete(String streamId) throws IOException {
       synchronized (partitionInfo) {
-        logger.trace("{} shuffleId {} reduceId {} onComplete invoked",
-          partitionInfo.appId, partitionInfo.shuffleId,
+        logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} onComplete invoked",
+          partitionInfo.appId, partitionInfo.shuffleId, partitionInfo.shuffleMergeId,
           partitionInfo.reduceId);
-        Map<Integer, AppShufflePartitionInfo> shufflePartitions =
-          appShuffleInfo.partitions.get(partitionInfo.shuffleId);
-        // When this request initially got to the server, the shuffle merge finalize request
-        // was not received yet. By the time we finish reading this message, the shuffle merge
-        // however is already finalized. We should thus respond RpcFailure to the client.
-        if (shufflePartitions == null || !shufflePartitions.containsKey(partitionInfo.reduceId)) {
+        // When this request initially got to the server, the shuffle merge finalize request
+        // had not been received yet and this was the latest stage attempt (i.e. the latest
+        // shuffleMergeId) generating shuffle output for the shuffle ID. By the time we
+        // finish reading this message, the block push is either stale or too late. We
+        // should thus respond RpcFailure to the client.
+        if (isStaleOrTooLate(appShuffleInfo.shuffles.get(partitionInfo.shuffleId),
+            partitionInfo.shuffleMergeId, partitionInfo.reduceId)) {
           deferredBufs = null;
-          throw new RuntimeException(String.format("Block %s %s", streamId,
-            ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX));
+          throw new RuntimeException(String.format("Block %s is %s", streamId,
+            ErrorHandler.BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX));
         }
+
         // Check if we can commit this block
         if (allowedToWrite()) {
           // Identify duplicate block generated by speculative tasks. We respond success to
@@ -800,21 +929,20 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
         logger.debug("Encountered issue when merging {}", streamId, throwable);
       }
       // Only update partitionInfo if the failure corresponds to a valid request. If the
-      // request is too late, i.e. received after shuffle merge finalize, #onFailure will
-      // also be triggered, and we can just ignore. Also, if we couldn't find an opportunity
-      // to write the block data to disk, we should also ignore here.
+      // request is too late, i.e. received after shuffle merge finalize, or is a stale block
+      // push, #onFailure will also be triggered, and we can just ignore it. Also, if we
+      // couldn't find an opportunity to write the block data to disk, we should ignore it here.
       if (isWriting) {
         synchronized (partitionInfo) {
-          Map<Integer, AppShufflePartitionInfo> shufflePartitions =
-            appShuffleInfo.partitions.get(partitionInfo.shuffleId);
-          if (shufflePartitions != null && shufflePartitions.containsKey(partitionInfo.reduceId)) {
-            logger.debug("{} shuffleId {} reduceId {} encountered failure",
-              partitionInfo.appId, partitionInfo.shuffleId,
-              partitionInfo.reduceId);
-            partitionInfo.setCurrentMapIndex(-1);
+          if (!isStaleOrTooLate(appShuffleInfo.shuffles.get(partitionInfo.shuffleId),
+              partitionInfo.shuffleMergeId, partitionInfo.reduceId)) {
+            logger.debug("{} shuffleId {} shuffleMergeId {} reduceId {}"
+              + " encountered failure", partitionInfo.appId, partitionInfo.shuffleId,
+              partitionInfo.shuffleMergeId, partitionInfo.reduceId);
+            partitionInfo.setCurrentMapIndex(-1);
+          }
           }
         }
-      }
       isWriting = false;
     }
 
@@ -824,12 +952,35 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
     }
   }
 
+  /**
+   * Wrapper class holding merged-shuffle related information for a specific shuffleMergeId,
+   * required for the shuffles of indeterminate stages.
+   */
+  public static class AppShuffleMergePartitionsInfo {
+    private final int shuffleMergeId;
+    private final Map<Integer, AppShufflePartitionInfo> shuffleMergePartitions;
+
+    public AppShuffleMergePartitionsInfo(
+        int shuffleMergeId, boolean shuffleFinalized) {
+      this.shuffleMergeId = shuffleMergeId;
+      this.shuffleMergePartitions = shuffleFinalized ?
+        INDETERMINATE_SHUFFLE_FINALIZED : new ConcurrentHashMap<>();
+    }
+
+    @VisibleForTesting
+    public Map<Integer, AppShufflePartitionInfo> getShuffleMergePartitions() {
+      return shuffleMergePartitions;
+    }
+  }
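
An illustrative sketch, not the resolver's code, of how such a map entry can be advanced
atomically when a higher shuffleMergeId arrives: ConcurrentHashMap.compute swaps in fresh
partition state, and an empty map stands in for the INDETERMINATE_SHUFFLE_FINALIZED
sentinel. All names below are assumptions for illustration:

    import java.util.Collections;
    import java.util.Map;
    import java.util.concurrent.ConcurrentHashMap;
    import java.util.concurrent.ConcurrentMap;

    class MergeIdAdvanceSketch {
      // empty map standing in for the INDETERMINATE_SHUFFLE_FINALIZED sentinel
      static final Map<Integer, Object> FINALIZED_SENTINEL = Collections.emptyMap();

      static class MergeInfo {
        final int shuffleMergeId;
        final Map<Integer, Object> partitions;
        MergeInfo(int shuffleMergeId, Map<Integer, Object> partitions) {
          this.shuffleMergeId = shuffleMergeId;
          this.partitions = partitions;
        }
      }

      final ConcurrentMap<Integer, MergeInfo> shuffles = new ConcurrentHashMap<>();

      MergeInfo advance(int shuffleId, int requestMergeId) {
        return shuffles.compute(shuffleId, (id, current) -> {
          if (current == null || requestMergeId > current.shuffleMergeId) {
            // newer attempt: the older partitions would be handed off for cleanup
            // here before being replaced
            return new MergeInfo(requestMergeId, new ConcurrentHashMap<>());
          }
          return current; // same or older merge id: keep the existing state
        });
      }

      void markFinalized(int shuffleId, int mergeId) {
        shuffles.put(shuffleId, new MergeInfo(mergeId, FINALIZED_SENTINEL));
      }
    }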
+
   /** Metadata tracked for an actively merged shuffle partition */
   public static class AppShufflePartitionInfo {
 
     private final String appId;
     private final int shuffleId;
+    private final int shuffleMergeId;
     private final int reduceId;
+    private final File dataFile;
     // The merged shuffle data file channel
     public final FileChannel dataChannel;
     // The index file for a particular merged shuffle contains the chunk offsets.
@@ -854,6 +1005,7 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
     AppShufflePartitionInfo(
         String appId,
         int shuffleId,
+        int shuffleMergeId,
         int reduceId,
         File dataFile,
         MergeShuffleFile indexFile,
@@ -861,8 +1013,10 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
       Preconditions.checkArgument(appId != null, "app id is null");
       this.appId = appId;
       this.shuffleId = shuffleId;
+      this.shuffleMergeId = shuffleMergeId;
       this.reduceId = reduceId;
       this.dataChannel = new FileOutputStream(dataFile).getChannel();
+      this.dataFile = dataFile;
       this.indexFile = indexFile;
       this.metaFile = metaFile;
       this.currentMapIndex = -1;
@@ -878,8 +1032,9 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
     }
 
     public void setDataFilePos(long dataFilePos) {
-      logger.trace("{} shuffleId {} reduceId {} current pos {} update pos {}", appId,
-        shuffleId, reduceId, this.dataFilePos, dataFilePos);
+      logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} current pos {}"
+        + " update pos {}", appId, shuffleId, shuffleMergeId, reduceId, this.dataFilePos,
+        dataFilePos);
       this.dataFilePos = dataFilePos;
     }
 
@@ -888,8 +1043,9 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
     }
 
     void setCurrentMapIndex(int mapIndex) {
-      logger.trace("{} shuffleId {} reduceId {} updated mapIndex {} current mapIndex {}",
-        appId, shuffleId, reduceId, currentMapIndex, mapIndex);
+      logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} updated mapIndex {}"
+        + " current mapIndex {}", appId, shuffleId, shuffleMergeId, reduceId,
+        currentMapIndex, mapIndex);
       this.currentMapIndex = mapIndex;
     }
 
@@ -898,8 +1054,8 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
     }
 
     void blockMerged(int mapIndex) {
-      logger.debug("{} shuffleId {} reduceId {} updated merging mapIndex {}", appId,
-        shuffleId, reduceId, mapIndex);
+      logger.debug("{} shuffleId {} shuffleMergeId {} reduceId {} updated merging mapIndex {}",
+        appId, shuffleId, shuffleMergeId, reduceId, mapIndex);
       mapTracker.add(mapIndex);
       chunkTracker.add(mapIndex);
       lastMergedMapIndex = mapIndex;
@@ -917,8 +1073,9 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
      */
     void updateChunkInfo(long chunkOffset, int mapIndex) throws IOException {
       try {
-        logger.trace("{} shuffleId {} reduceId {} index current {} updated {}",
-          appId, shuffleId, reduceId, this.lastChunkOffset, chunkOffset);
+        logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} index current {}"
+          + " updated {}", appId, shuffleId, shuffleMergeId, reduceId,
+          this.lastChunkOffset, chunkOffset);
         if (indexMetaUpdateFailed) {
           indexFile.getChannel().position(indexFile.getPos());
         }
@@ -946,8 +1103,8 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
         return;
       }
       chunkTracker.add(mapIndex);
-      logger.trace("{} shuffleId {} reduceId {} mapIndex {} write chunk to meta file",
-        appId, shuffleId, reduceId, mapIndex);
+      logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} mapIndex {}"
+        + " write chunk to meta file", appId, shuffleId, shuffleMergeId, reduceId, mapIndex);
       if (indexMetaUpdateFailed) {
         metaFile.getChannel().position(metaFile.getPos());
       }
@@ -980,32 +1137,41 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
       metaFile.getChannel().truncate(metaFile.getPos());
     }
 
-    void closeAllFiles() {
+    void closeAllFilesAndDeleteIfNeeded(boolean delete) {
       try {
         if (dataChannel.isOpen()) {
           dataChannel.close();
+          if (delete) {
+            dataFile.delete();
+          }
         }
       } catch (IOException ioe) {
-        logger.warn("Error closing data channel for {} shuffleId {} reduceId {}",
-          appId, shuffleId, reduceId);
+        logger.warn("Error closing data channel for {} shuffleId {} shuffleMergeId {}"
+          + " reduceId {}", appId, shuffleId, shuffleMergeId, reduceId);
       }
       try {
         metaFile.close();
+        if (delete) {
+          metaFile.delete();
+        }
       } catch (IOException ioe) {
-        logger.warn("Error closing meta file for {} shuffleId {} reduceId {}",
-          appId, shuffleId, reduceId);
-      }
+        logger.warn("Error closing meta file for {} shuffleId {} shuffleMergeId {}"
+          + " reduceId {}", appId, shuffleId, shuffleMergeId, reduceId);
+      }
       try {
         indexFile.close();
+        if (delete) {
+          indexFile.delete();
+        }
       } catch (IOException ioe) {
-        logger.warn("Error closing index file for {} shuffleId {} reduceId {}",
-          appId, shuffleId, reduceId);
+        logger.warn("Error closing index file for {} shuffleId {} shuffleMergeId {}"
+          + " reduceId {}", appId, shuffleId, shuffleMergeId, reduceId);
       }
     }
 
     @Override
     protected void finalize() throws Throwable {
-      closeAllFiles();
+      closeAllFilesAndDeleteIfNeeded(false);
     }
 
     @VisibleForTesting
@@ -1065,7 +1231,12 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
     private final String appId;
     private final int attemptId;
     private final AppPathsInfo appPathsInfo;
-    private final ConcurrentMap<Integer, Map<Integer, AppShufflePartitionInfo>> partitions;
+    /**
+     * Key is the shuffleId for an application; value is the AppShuffleMergePartitionsInfo,
+     * which tracks the shuffleMergeId and a Map of AppShufflePartitionInfo for all the
+     * shuffle partitions.
+     */
+    private final ConcurrentMap<Integer, AppShuffleMergePartitionsInfo> shuffles;
 
     AppShuffleInfo(
         String appId,
@@ -1074,12 +1245,12 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
       this.appId = appId;
       this.attemptId = attemptId;
       this.appPathsInfo = appPathsInfo;
-      partitions = new ConcurrentHashMap<>();
+      shuffles = new ConcurrentHashMap<>();
     }
 
     @VisibleForTesting
-    public ConcurrentMap<Integer, Map<Integer, AppShufflePartitionInfo>> getPartitions() {
-      return partitions;
+    public ConcurrentMap<Integer, AppShuffleMergePartitionsInfo> getShuffles() {
+      return shuffles;
     }
 
     /**
@@ -1098,29 +1269,37 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
     private String generateFileName(
         String appId,
         int shuffleId,
+        int shuffleMergeId,
         int reduceId) {
       return String.format(
-        "%s_%s_%d_%d", MERGED_SHUFFLE_FILE_NAME_PREFIX, appId, shuffleId, reduceId);
+        "%s_%s_%d_%d_%d", MERGED_SHUFFLE_FILE_NAME_PREFIX, appId, shuffleId,
+        shuffleMergeId, reduceId);
     }
 
     public File getMergedShuffleDataFile(
         int shuffleId,
+        int shuffleMergeId,
         int reduceId) {
-      String fileName = String.format("%s.data", generateFileName(appId, shuffleId, reduceId));
+      String fileName = String.format("%s.data", generateFileName(appId, shuffleId,
+        shuffleMergeId, reduceId));
       return getFile(fileName);
     }
 
     public File getMergedShuffleIndexFile(
         int shuffleId,
+        int shuffleMergeId,
         int reduceId) {
-      String indexName = String.format("%s.index", generateFileName(appId, shuffleId, reduceId));
+      String indexName = String.format("%s.index", generateFileName(appId, shuffleId,
+        shuffleMergeId, reduceId));
       return getFile(indexName);
     }
 
     public File getMergedShuffleMetaFile(
         int shuffleId,
+        int shuffleMergeId,
         int reduceId) {
-      String metaName = String.format("%s.meta", generateFileName(appId, shuffleId, reduceId));
+      String metaName = String.format("%s.meta", generateFileName(appId, shuffleId,
+        shuffleMergeId, reduceId));
       return getFile(metaName);
     }
   }
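
A short sketch of the on-disk names the generateFileName change produces; the prefix value
and the ids below are illustrative assumptions, not values taken from a running cluster:

    class MergedFileNameSketch {
      public static void main(String[] args) {
        String prefix = "shuffleMerged"; // assumed MERGED_SHUFFLE_FILE_NAME_PREFIX value
        String appId = "app-20210801-0001"; // illustrative
        int shuffleId = 0, shuffleMergeId = 2, reduceId = 5;
        String base = String.format("%s_%s_%d_%d_%d",
          prefix, appId, shuffleId, shuffleMergeId, reduceId);
        System.out.println(base + ".data");  // shuffleMerged_app-20210801-0001_0_2_5.data
        System.out.println(base + ".index");
        System.out.println(base + ".meta");
      }
    }

Embedding shuffleMergeId in the name keeps merged files from different merge attempts of
the same shuffle from colliding on disk.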
@@ -1130,18 +1309,21 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
     private final FileChannel channel;
     private final DataOutputStream dos;
     private long pos;
+    private File file;
 
     @VisibleForTesting
     MergeShuffleFile(File file) throws IOException {
       FileOutputStream fos = new FileOutputStream(file);
       channel = fos.getChannel();
       dos = new DataOutputStream(fos);
+      this.file = file;
     }
 
     @VisibleForTesting
     MergeShuffleFile(FileChannel channel, DataOutputStream dos) {
       this.channel = channel;
       this.dos = dos;
+      this.file = null;
     }
 
     private void updatePos(long numBytes) {
@@ -1154,6 +1336,16 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
       }
     }
 
+    void delete() throws IOException {
+      try {
+        if (null != file) {
+          file.delete();
+        }
+      } finally {
+        file = null;
+      }
+    }
+
     @VisibleForTesting
     DataOutputStream getDos() {
       return dos;
@@ -1169,4 +1361,10 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
       return pos;
     }
   }
+
+  public static class StaleBlockPushException extends RuntimeException {
+    public StaleBlockPushException(String message) {
+      super(message);
+    }
+  }
 }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunks.java
index 27345dd..cf4cbcf 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunks.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunks.java
@@ -37,14 +37,19 @@ public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks {
   public final int[] reduceIds;
   // The i-th int[] in chunkIds contains all the chunks for the i-th reduceId in reduceIds.
   public final int[][] chunkIds;
+  // shuffleMergeId is used to uniquely identify the merging process of the shuffle by
+  // an indeterminate stage attempt.
+  public final int shuffleMergeId;
 
   public FetchShuffleBlockChunks(
       String appId,
       String execId,
       int shuffleId,
+      int shuffleMergeId,
       int[] reduceIds,
       int[][] chunkIds) {
     super(appId, execId, shuffleId);
+    this.shuffleMergeId = shuffleMergeId;
     this.reduceIds = reduceIds;
     this.chunkIds = chunkIds;
     assert(reduceIds.length == chunkIds.length);
@@ -56,6 +61,7 @@ public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks {
   @Override
   public String toString() {
     return toStringHelper()
+      .append("shuffleMergeId", shuffleMergeId)
       .append("reduceIds", Arrays.toString(reduceIds))
       .append("chunkIds", Arrays.deepToString(chunkIds))
       .toString();
@@ -68,13 +74,16 @@ public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks {
 
     FetchShuffleBlockChunks that = (FetchShuffleBlockChunks) o;
     if (!super.equals(that)) return false;
-    if (!Arrays.equals(reduceIds, that.reduceIds)) return false;
+    if (shuffleMergeId != that.shuffleMergeId ||
+      !Arrays.equals(reduceIds, that.reduceIds)) {
+      return false;
+    }
     return Arrays.deepEquals(chunkIds, that.chunkIds);
   }
 
   @Override
   public int hashCode() {
-    int result = super.hashCode();
+    int result = super.hashCode() * 31 + shuffleMergeId;
     result = 31 * result + Arrays.hashCode(reduceIds);
     result = 31 * result + Arrays.deepHashCode(chunkIds);
     return result;
@@ -89,12 +98,14 @@ public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks {
     return super.encodedLength()
       + Encoders.IntArrays.encodedLength(reduceIds)
       + 4 /* encoded length of chunkIds.size() */
+      + 4 /* encoded length of shuffleMergeId */
       + encodedLengthOfChunkIds;
   }
 
   @Override
   public void encode(ByteBuf buf) {
     super.encode(buf);
+    buf.writeInt(shuffleMergeId);
     Encoders.IntArrays.encode(buf, reduceIds);
     // Even though reduceIds.length == chunkIds.length, we are explicitly setting the length in the
     // interest of forward compatibility.
@@ -117,12 +128,14 @@ public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks {
     String appId = Encoders.Strings.decode(buf);
     String execId = Encoders.Strings.decode(buf);
     int shuffleId = buf.readInt();
+    int shuffleMergeId = buf.readInt();
     int[] reduceIds = Encoders.IntArrays.decode(buf);
     int chunkIdsLen = buf.readInt();
     int[][] chunkIds = new int[chunkIdsLen][];
     for (int i = 0; i < chunkIdsLen; i++) {
       chunkIds[i] = Encoders.IntArrays.decode(buf);
     }
-    return new FetchShuffleBlockChunks(appId, execId, shuffleId, reduceIds, chunkIds);
+    return new FetchShuffleBlockChunks(appId, execId, shuffleId, shuffleMergeId, reduceIds,
+      chunkIds);
   }
 }
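
A round-trip sketch of the extended wire format: shuffleMergeId is written right after the
base fields and before reduceIds, matching the encode() and decode() shown above. Run with
-ea to enable the assert; the ids are illustrative:

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;
    import org.apache.spark.network.shuffle.protocol.FetchShuffleBlockChunks;

    class FetchChunksRoundTripSketch {
      public static void main(String[] args) {
        FetchShuffleBlockChunks msg = new FetchShuffleBlockChunks(
          "app-id", "exec-id", /* shuffleId */ 0, /* shuffleMergeId */ 1,
          new int[] {0, 1}, new int[][] {{0, 1}, {0}});
        ByteBuf buf = Unpooled.buffer(msg.encodedLength());
        msg.encode(buf);
        FetchShuffleBlockChunks decoded = FetchShuffleBlockChunks.decode(buf);
        assert msg.equals(decoded); // shuffleMergeId survives the round trip
      }
    }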
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FinalizeShuffleMerge.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FinalizeShuffleMerge.java
index 088ff38..675739a 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FinalizeShuffleMerge.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FinalizeShuffleMerge.java
@@ -34,14 +34,17 @@ public class FinalizeShuffleMerge extends BlockTransferMessage {
   public final String appId;
   public final int appAttemptId;
   public final int shuffleId;
+  public final int shuffleMergeId;
 
   public FinalizeShuffleMerge(
       String appId,
       int appAttemptId,
-      int shuffleId) {
+      int shuffleId,
+      int shuffleMergeId) {
     this.appId = appId;
     this.appAttemptId = appAttemptId;
     this.shuffleId = shuffleId;
+    this.shuffleMergeId = shuffleMergeId;
   }
 
   @Override
@@ -51,7 +54,7 @@ public class FinalizeShuffleMerge extends BlockTransferMessage {
 
   @Override
   public int hashCode() {
-    return Objects.hashCode(appId, appAttemptId, shuffleId);
+    return Objects.hashCode(appId, appAttemptId, shuffleId, shuffleMergeId);
   }
 
   @Override
@@ -60,6 +63,7 @@ public class FinalizeShuffleMerge extends BlockTransferMessage {
       .append("appId", appId)
       .append("attemptId", appAttemptId)
       .append("shuffleId", shuffleId)
+      .append("shuffleMergeId", shuffleMergeId)
       .toString();
   }
 
@@ -69,14 +73,15 @@ public class FinalizeShuffleMerge extends BlockTransferMessage {
       FinalizeShuffleMerge o = (FinalizeShuffleMerge) other;
       return Objects.equal(appId, o.appId)
         && appAttemptId == o.appAttemptId
-        && shuffleId == o.shuffleId;
+        && shuffleId == o.shuffleId
+        && shuffleMergeId == o.shuffleMergeId;
     }
     return false;
   }
 
   @Override
   public int encodedLength() {
-    return Encoders.Strings.encodedLength(appId) + 4 + 4;
+    return Encoders.Strings.encodedLength(appId) + 4 + 4 + 4;
   }
 
   @Override
@@ -84,12 +89,14 @@ public class FinalizeShuffleMerge extends BlockTransferMessage {
     Encoders.Strings.encode(buf, appId);
     buf.writeInt(appAttemptId);
     buf.writeInt(shuffleId);
+    buf.writeInt(shuffleMergeId);
   }
 
   public static FinalizeShuffleMerge decode(ByteBuf buf) {
     String appId = Encoders.Strings.decode(buf);
     int attemptId = buf.readInt();
     int shuffleId = buf.readInt();
-    return new FinalizeShuffleMerge(appId, attemptId, shuffleId);
+    int shuffleMergeId = buf.readInt();
+    return new FinalizeShuffleMerge(appId, attemptId, shuffleId, shuffleMergeId);
   }
 }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergeStatuses.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergeStatuses.java
index 142ab73..b2658d6 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergeStatuses.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergeStatuses.java
@@ -41,6 +41,11 @@ public class MergeStatuses extends BlockTransferMessage {
   /** Shuffle ID **/
   public final int shuffleId;
   /**
+   * shuffleMergeId is used to uniquely identify the merging process of the shuffle by
+   * an indeterminate stage attempt.
+   */
+  public final int shuffleMergeId;
+  /**
    * Array of bitmaps tracking the set of mapper partition blocks merged for each
    * reducer partition
    */
@@ -55,10 +60,12 @@ public class MergeStatuses extends BlockTransferMessage {
 
   public MergeStatuses(
       int shuffleId,
+      int shuffleMergeId,
       RoaringBitmap[] bitmaps,
       int[] reduceIds,
       long[] sizes) {
     this.shuffleId = shuffleId;
+    this.shuffleMergeId = shuffleMergeId;
     this.bitmaps = bitmaps;
     this.reduceIds = reduceIds;
     this.sizes = sizes;
@@ -71,7 +78,8 @@ public class MergeStatuses extends BlockTransferMessage {
 
   @Override
   public int hashCode() {
-    int objectHashCode = Objects.hashCode(shuffleId);
+    int objectHashCode = Objects.hashCode(shuffleId) * 41 +
+        Objects.hashCode(shuffleMergeId);
     return (objectHashCode * 41 + Arrays.hashCode(reduceIds) * 41
       + Arrays.hashCode(bitmaps) * 41 + Arrays.hashCode(sizes));
   }
@@ -80,6 +88,7 @@ public class MergeStatuses extends BlockTransferMessage {
   public String toString() {
     return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
       .append("shuffleId", shuffleId)
+      .append("shuffleMergeId", shuffleMergeId)
       .append("reduceId size", reduceIds.length)
       .toString();
   }
@@ -89,6 +98,7 @@ public class MergeStatuses extends BlockTransferMessage {
     if (other != null && other instanceof MergeStatuses) {
       MergeStatuses o = (MergeStatuses) other;
       return Objects.equal(shuffleId, o.shuffleId)
+        && Objects.equal(shuffleMergeId, o.shuffleMergeId)
         && Arrays.equals(bitmaps, o.bitmaps)
         && Arrays.equals(reduceIds, o.reduceIds)
         && Arrays.equals(sizes, o.sizes);
@@ -98,7 +108,7 @@ public class MergeStatuses extends BlockTransferMessage {
 
   @Override
   public int encodedLength() {
-    return 4 // int
+    return 4 + 4 // shuffleId and shuffleMergeId
       + Encoders.BitmapArrays.encodedLength(bitmaps)
       + Encoders.IntArrays.encodedLength(reduceIds)
       + Encoders.LongArrays.encodedLength(sizes);
@@ -107,6 +117,7 @@ public class MergeStatuses extends BlockTransferMessage {
   @Override
   public void encode(ByteBuf buf) {
     buf.writeInt(shuffleId);
+    buf.writeInt(shuffleMergeId);
     Encoders.BitmapArrays.encode(buf, bitmaps);
     Encoders.IntArrays.encode(buf, reduceIds);
     Encoders.LongArrays.encode(buf, sizes);
@@ -114,9 +125,10 @@ public class MergeStatuses extends BlockTransferMessage {
 
   public static MergeStatuses decode(ByteBuf buf) {
     int shuffleId = buf.readInt();
+    int shuffleMergeId = buf.readInt();
     RoaringBitmap[] bitmaps = Encoders.BitmapArrays.decode(buf);
     int[] reduceIds = Encoders.IntArrays.decode(buf);
     long[] sizes = Encoders.LongArrays.decode(buf);
-    return new MergeStatuses(shuffleId, bitmaps, reduceIds, sizes);
+    return new MergeStatuses(shuffleId, shuffleMergeId, bitmaps, reduceIds, sizes);
   }
 }
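
On the receiving side, a client tracking its current merge attempt can use the new field
to discard stale finalize results, mirroring the staleness checks on the server. An
illustrative guard, not Spark code:

    class MergeStatusGuardSketch {
      // register merge results only if they came from the current merge attempt
      static boolean acceptStatuses(int currentMergeId, int statusesMergeId) {
        return statusesMergeId == currentMergeId;
      }

      public static void main(String[] args) {
        System.out.println(acceptStatuses(2, 1)); // false: stale finalize result
        System.out.println(acceptStatuses(2, 2)); // true: current attempt
      }
    }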
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/PushBlockStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/PushBlockStream.java
index d5e1cf2..b868d7c 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/PushBlockStream.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/PushBlockStream.java
@@ -37,6 +37,7 @@ public class PushBlockStream extends BlockTransferMessage {
   public final String appId;
   public final int appAttemptId;
   public final int shuffleId;
+  public final int shuffleMergeId;
   public final int mapIndex;
   public final int reduceId;
   // Similar to the chunkIndex in StreamChunkId, indicating the index of a block in a batch of
@@ -47,12 +48,14 @@ public class PushBlockStream extends BlockTransferMessage {
       String appId,
       int appAttemptId,
       int shuffleId,
+      int shuffleMergeId,
       int mapIndex,
       int reduceId,
       int index) {
     this.appId = appId;
     this.appAttemptId = appAttemptId;
     this.shuffleId = shuffleId;
+    this.shuffleMergeId = shuffleMergeId;
     this.mapIndex = mapIndex;
     this.reduceId = reduceId;
     this.index = index;
@@ -65,7 +68,8 @@ public class PushBlockStream extends BlockTransferMessage {
 
   @Override
   public int hashCode() {
-    return Objects.hashCode(appId, appAttemptId, shuffleId, mapIndex , reduceId, index);
+    return Objects.hashCode(appId, appAttemptId, shuffleId, shuffleMergeId, mapIndex, reduceId,
+      index);
   }
 
   @Override
@@ -74,6 +78,7 @@ public class PushBlockStream extends BlockTransferMessage {
       .append("appId", appId)
       .append("attemptId", appAttemptId)
       .append("shuffleId", shuffleId)
+      .append("shuffleMergeId", shuffleMergeId)
       .append("mapIndex", mapIndex)
       .append("reduceId", reduceId)
       .append("index", index)
@@ -87,6 +92,7 @@ public class PushBlockStream extends BlockTransferMessage {
       return Objects.equal(appId, o.appId)
         && appAttemptId == o.appAttemptId
         && shuffleId == o.shuffleId
+        && shuffleMergeId == o.shuffleMergeId
         && mapIndex == o.mapIndex
         && reduceId == o.reduceId
         && index == o.index;
@@ -96,7 +102,7 @@ public class PushBlockStream extends BlockTransferMessage {
 
   @Override
   public int encodedLength() {
-    return Encoders.Strings.encodedLength(appId) + 4 + 4 + 4 + 4 + 4;
+    return Encoders.Strings.encodedLength(appId) + 4 + 4 + 4 + 4 + 4 + 4;
   }
 
   @Override
@@ -104,6 +110,7 @@ public class PushBlockStream extends BlockTransferMessage {
     Encoders.Strings.encode(buf, appId);
     buf.writeInt(appAttemptId);
     buf.writeInt(shuffleId);
+    buf.writeInt(shuffleMergeId);
     buf.writeInt(mapIndex);
     buf.writeInt(reduceId);
     buf.writeInt(index);
@@ -113,9 +120,11 @@ public class PushBlockStream extends BlockTransferMessage {
     String appId = Encoders.Strings.decode(buf);
     int attemptId = buf.readInt();
     int shuffleId = buf.readInt();
+    int shuffleMergeId = buf.readInt();
     int mapIdx = buf.readInt();
     int reduceId = buf.readInt();
     int index = buf.readInt();
-    return new PushBlockStream(appId, attemptId, shuffleId, mapIdx, reduceId, index);
+    return new PushBlockStream(appId, attemptId, shuffleId, shuffleMergeId, mapIdx, reduceId,
+      index);
   }
 }
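
With the new field, the push stream id built in receiveBlockDataAsStream carries five
parts, shufflePush_<shuffleId>_<shuffleMergeId>_<mapIndex>_<reduceId>. A small sketch with
illustrative ids; the parse step is hypothetical, not a Spark helper:

    class PushStreamIdSketch {
      public static void main(String[] args) {
        String streamId = String.format("%s_%d_%d_%d_%d",
          "shufflePush", /* shuffleId */ 0, /* shuffleMergeId */ 1,
          /* mapIndex */ 7, /* reduceId */ 3);
        System.out.println(streamId); // shufflePush_0_1_7_3

        String[] parts = streamId.split("_");
        int shuffleMergeId = Integer.parseInt(parts[2]);
        System.out.println(shuffleMergeId); // 1
      }
    }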
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ErrorHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ErrorHandlerSuite.java
index 992e776..c8066d1 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ErrorHandlerSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ErrorHandlerSuite.java
@@ -29,23 +29,31 @@ import static org.junit.Assert.*;
 public class ErrorHandlerSuite {
 
   @Test
-  public void testPushErrorRetry() {
-    ErrorHandler.BlockPushErrorHandler handler = new ErrorHandler.BlockPushErrorHandler();
-    assertFalse(handler.shouldRetryError(new RuntimeException(new IllegalArgumentException(
-      ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX))));
-    assertFalse(handler.shouldRetryError(new RuntimeException(new ConnectException())));
-    assertTrue(handler.shouldRetryError(new RuntimeException(new IllegalArgumentException(
+  public void testErrorRetry() {
+    ErrorHandler.BlockPushErrorHandler pushHandler = new ErrorHandler.BlockPushErrorHandler();
+    assertFalse(pushHandler.shouldRetryError(new RuntimeException(new IllegalArgumentException(
+      ErrorHandler.BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX))));
+    assertFalse(pushHandler.shouldRetryError(new RuntimeException(new ConnectException())));
+    assertTrue(pushHandler.shouldRetryError(new RuntimeException(new IllegalArgumentException(
       ErrorHandler.BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX))));
-    assertTrue(handler.shouldRetryError(new Throwable()));
+    assertTrue(pushHandler.shouldRetryError(new Throwable()));
+
+    ErrorHandler.BlockFetchErrorHandler fetchHandler = new ErrorHandler.BlockFetchErrorHandler();
+    assertFalse(fetchHandler.shouldRetryError(new RuntimeException(
+      ErrorHandler.BlockFetchErrorHandler.STALE_SHUFFLE_BLOCK_FETCH)));
   }
 
   @Test
-  public void testPushErrorLogging() {
-    ErrorHandler.BlockPushErrorHandler handler = new ErrorHandler.BlockPushErrorHandler();
-    assertFalse(handler.shouldLogError(new RuntimeException(new IllegalArgumentException(
-      ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX))));
-    assertFalse(handler.shouldLogError(new RuntimeException(new IllegalArgumentException(
+  public void testErrorLogging() {
+    ErrorHandler.BlockPushErrorHandler pushHandler = new ErrorHandler.BlockPushErrorHandler();
+    assertFalse(pushHandler.shouldLogError(new RuntimeException(new IllegalArgumentException(
+      ErrorHandler.BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX))));
+    assertFalse(pushHandler.shouldLogError(new RuntimeException(new IllegalArgumentException(
       ErrorHandler.BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX))));
-    assertTrue(handler.shouldLogError(new Throwable()));
+    assertTrue(pushHandler.shouldLogError(new Throwable()));
+
+    ErrorHandler.BlockFetchErrorHandler fetchHandler = new ErrorHandler.BlockFetchErrorHandler();
+    assertFalse(fetchHandler.shouldLogError(new RuntimeException(
+      ErrorHandler.BlockFetchErrorHandler.STALE_SHUFFLE_BLOCK_FETCH)));
   }
 }
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java
index 00756b1..9e0b3c6 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java
@@ -243,9 +243,9 @@ public class ExternalBlockHandlerSuite {
   public void testFinalizeShuffleMerge() throws IOException {
     RpcResponseCallback callback = mock(RpcResponseCallback.class);
 
-    FinalizeShuffleMerge req = new FinalizeShuffleMerge("app0", 1, 0);
+    FinalizeShuffleMerge req = new FinalizeShuffleMerge("app0", 1, 0, 0);
     RoaringBitmap bitmap = RoaringBitmap.bitmapOf(0, 1, 2);
-    MergeStatuses statuses = new MergeStatuses(0, new RoaringBitmap[]{bitmap},
+    MergeStatuses statuses = new MergeStatuses(0, 0, new RoaringBitmap[]{bitmap},
       new int[]{3}, new long[]{30});
     when(mergedShuffleManager.finalizeShuffleMerge(req)).thenReturn(statuses);
 
@@ -269,22 +269,22 @@ public class ExternalBlockHandlerSuite {
 
   @Test
   public void testFetchMergedBlocksMeta() {
-    when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 0)).thenReturn(
+    when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 0, 0)).thenReturn(
       new MergedBlockMeta(1, mock(ManagedBuffer.class)));
-    when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 1)).thenReturn(
+    when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 0, 1)).thenReturn(
       new MergedBlockMeta(3, mock(ManagedBuffer.class)));
-    when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 2)).thenReturn(
+    when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 0, 2)).thenReturn(
       new MergedBlockMeta(5, mock(ManagedBuffer.class)));
 
     int[] expectedCount = new int[]{1, 3, 5};
     String appId = "app0";
     long requestId = 0L;
     for (int reduceId = 0; reduceId < 3; reduceId++) {
-      MergedBlockMetaRequest req = new MergedBlockMetaRequest(requestId++, appId, 0, reduceId);
+      MergedBlockMetaRequest req = new MergedBlockMetaRequest(requestId++, appId, 0, 0, reduceId);
       MergedBlockMetaResponseCallback callback = mock(MergedBlockMetaResponseCallback.class);
       handler.getMergedBlockMetaReqHandler()
         .receiveMergeBlockMetaReq(client, req, callback);
-      verify(mergedShuffleManager, times(1)).getMergedBlockMeta("app0", 0, reduceId);
+      verify(mergedShuffleManager, times(1)).getMergedBlockMeta("app0", 0, 0, reduceId);
 
       ArgumentCaptor<Integer> numChunksResponse = ArgumentCaptor.forClass(Integer.class);
       ArgumentCaptor<ManagedBuffer> chunkBitmapResponse =
@@ -313,12 +313,12 @@ public class ExternalBlockHandlerSuite {
     if (useOpenBlocks) {
       OpenBlocks openBlocks =
         new OpenBlocks("app0", "exec1",
-          new String[] {"shuffleChunk_0_0_0", "shuffleChunk_0_0_1", "shuffleChunk_0_1_0",
-            "shuffleChunk_0_1_1"});
+          new String[] {"shuffleChunk_0_0_0_0", "shuffleChunk_0_0_0_1", "shuffleChunk_0_0_1_0",
+            "shuffleChunk_0_0_1_1"});
       buffer = openBlocks.toByteBuffer();
     } else {
       FetchShuffleBlockChunks fetchChunks = new FetchShuffleBlockChunks(
-        "app0", "exec1", 0, new int[] {0, 1}, new int[][] {{0, 1}, {0, 1}});
+        "app0", "exec1", 0, 0, new int[] {0, 1}, new int[][] {{0, 1}, {0, 1}});
       buffer = fetchChunks.toByteBuffer();
     }
     ManagedBuffer[][] buffers = new ManagedBuffer[][] {
@@ -334,7 +334,7 @@ public class ExternalBlockHandlerSuite {
     for (int reduceId = 0; reduceId < 2; reduceId++) {
       for (int chunkId = 0; chunkId < 2; chunkId++) {
         when(mergedShuffleManager.getMergedBlockData(
-          "app0", 0, reduceId, chunkId)).thenReturn(buffers[reduceId][chunkId]);
+          "app0", 0, 0, reduceId, chunkId)).thenReturn(buffers[reduceId][chunkId]);
       }
     }
     handler.receive(client, buffer, callback);
@@ -356,11 +356,12 @@ public class ExternalBlockHandlerSuite {
       }
     }
     assertFalse(bufferIter.hasNext());
-    verify(mergedShuffleManager, never()).getMergedBlockMeta(anyString(), anyInt(), anyInt());
+    verify(mergedShuffleManager, never()).getMergedBlockMeta(anyString(), anyInt(), anyInt(),
+        anyInt());
     verify(blockResolver, never()).getBlockData(
       anyString(), anyString(), anyInt(), anyInt(), anyInt());
-    verify(mergedShuffleManager, times(1)).getMergedBlockData("app0", 0, 0, 0);
-    verify(mergedShuffleManager, times(1)).getMergedBlockData("app0", 0, 0, 1);
+    verify(mergedShuffleManager, times(1)).getMergedBlockData("app0", 0, 0, 0, 0);
+    verify(mergedShuffleManager, times(1)).getMergedBlockData("app0", 0, 0, 0, 1);
 
     // Verify open block request latency metrics
     Timer openBlockRequestLatencyMillis = (Timer) ((ExternalBlockHandler) handler)
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
index 8e2ddf5..cc4640d 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
@@ -244,51 +244,49 @@ public class OneForOneBlockFetcherSuite {
   @Test
   public void testShuffleBlockChunksFetch() {
     LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
-    blocks.put("shuffleChunk_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
-    blocks.put("shuffleChunk_0_0_1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23])));
-    blocks.put("shuffleChunk_0_0_2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
+    blocks.put("shuffleChunk_0_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
+    blocks.put("shuffleChunk_0_0_0_1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23])));
+    blocks.put("shuffleChunk_0_0_0_2",
+      new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
     String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
 
     BlockFetchingListener listener = fetchBlocks(blocks, blockIds,
-      new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 },
+      new FetchShuffleBlockChunks("app-id", "exec-id", 0, 0, new int[] { 0 },
         new int[][] {{ 0, 1, 2 }}), conf);
     for (int i = 0; i < 3; i ++) {
-      verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_" + i,
-        blocks.get("shuffleChunk_0_0_" + i));
+      verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_0_" + i,
+        blocks.get("shuffleChunk_0_0_0_" + i));
     }
   }
 
   @Test
   public void testShuffleBlockChunkFetchFailure() {
     LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
-    blocks.put("shuffleChunk_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
-    blocks.put("shuffleChunk_0_0_1", null);
-    blocks.put("shuffleChunk_0_0_2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
+    blocks.put("shuffleChunk_0_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
+    blocks.put("shuffleChunk_0_0_0_1", null);
+    blocks.put("shuffleChunk_0_0_0_2",
+      new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
     String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
 
     BlockFetchingListener listener = fetchBlocks(blocks, blockIds,
-      new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[]{0}, new int[][]{{0, 1, 2}}),
+      new FetchShuffleBlockChunks("app-id", "exec-id", 0, 0, new int[]{0}, new int[][]{{0, 1, 2}}),
         conf);
-    verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_0",
-      blocks.get("shuffleChunk_0_0_0"));
-    verify(listener, times(1)).onBlockFetchFailure(eq("shuffleChunk_0_0_1"), any());
-    verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_2",
-      blocks.get("shuffleChunk_0_0_2"));
+    verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_0_0",
+      blocks.get("shuffleChunk_0_0_0_0"));
+    verify(listener, times(1)).onBlockFetchFailure(eq("shuffleChunk_0_0_0_1"), any());
+    verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_0_2",
+      blocks.get("shuffleChunk_0_0_0_2"));
   }
 
   @Test
   public void testInvalidShuffleBlockIds() {
     assertThrows(IllegalArgumentException.class, () -> fetchBlocks(new LinkedHashMap<>(),
       new String[]{"shuffle_0_0"},
-      new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 },
-        new int[][] {{ 0 }}), conf));
+      new FetchShuffleBlocks("app-id", "exec-id", 0, new long[] { 0 },
+        new int[][] {{ 0 }}, false), conf));
     assertThrows(IllegalArgumentException.class, () -> fetchBlocks(new LinkedHashMap<>(),
       new String[]{"shuffleChunk_0_0_0_0_0"},
-      new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 },
-        new int[][] {{ 0 }}), conf));
-    assertThrows(IllegalArgumentException.class, () -> fetchBlocks(new LinkedHashMap<>(),
-      new String[]{"shuffleChunk_0_0_0_0"},
-      new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 },
+      new FetchShuffleBlockChunks("app-id", "exec-id", 0, 0, new int[] { 0 },
         new int[][] {{ 0 }}), conf));
   }
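
The renamed chunk IDs above follow the new five-part layout
shuffleChunk_<shuffleId>_<shuffleMergeId>_<reduceId>_<chunkId>, one extra
shuffleMergeId component compared to the old four-part IDs. A minimal sketch of
decoding such an ID, assuming only the underscore-delimited layout these tests
encode (the ShuffleChunkId class and its parse method are illustrative, not
part of this patch):

    // Illustrative decoder for the new chunk ID layout, assuming
    // shuffleChunk_<shuffleId>_<shuffleMergeId>_<reduceId>_<chunkId>.
    final class ShuffleChunkId {
      final int shuffleId;
      final int shuffleMergeId;
      final int reduceId;
      final int chunkId;

      ShuffleChunkId(int shuffleId, int shuffleMergeId, int reduceId, int chunkId) {
        this.shuffleId = shuffleId;
        this.shuffleMergeId = shuffleMergeId;
        this.reduceId = reduceId;
        this.chunkId = chunkId;
      }

      static ShuffleChunkId parse(String blockId) {
        String[] parts = blockId.split("_");
        if (parts.length != 5 || !parts[0].equals("shuffleChunk")) {
          throw new IllegalArgumentException("Unexpected chunk block ID: " + blockId);
        }
        return new ShuffleChunkId(Integer.parseInt(parts[1]), Integer.parseInt(parts[2]),
          Integer.parseInt(parts[3]), Integer.parseInt(parts[4]));
      }
    }

For example, parse("shuffleChunk_0_0_0_2") yields shuffleId 0, shuffleMergeId 0,
reduceId 0, chunkId 2 (the third chunk in testShuffleBlockChunksFetch), while the
six-part "shuffleChunk_0_0_0_0_0" from testInvalidShuffleBlockIds is rejected.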
 
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockPusherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockPusherSuite.java
index f709a56..d2fd5d9 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockPusherSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockPusherSuite.java
@@ -45,77 +45,78 @@ public class OneForOneBlockPusherSuite {
   @Test
   public void testPushOne() {
     LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
-    blocks.put("shufflePush_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[1])));
+    blocks.put("shufflePush_0_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[1])));
     String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
 
     BlockPushingListener listener = pushBlocks(
       blocks,
       blockIds,
-      Arrays.asList(new PushBlockStream("app-id", 0, 0, 0, 0, 0)));
+      Arrays.asList(new PushBlockStream("app-id", 0, 0, 0, 0, 0, 0)));
 
-    verify(listener).onBlockPushSuccess(eq("shufflePush_0_0_0"), any());
+    verify(listener).onBlockPushSuccess(eq("shufflePush_0_0_0_0"), any());
   }
 
   @Test
   public void testPushThree() {
     LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
-    blocks.put("shufflePush_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
-    blocks.put("shufflePush_0_1_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[23])));
-    blocks.put("shufflePush_0_2_0", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
+    blocks.put("shufflePush_0_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
+    blocks.put("shufflePush_0_0_1_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[23])));
+    blocks.put("shufflePush_0_0_2_0",
+      new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
     String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
 
     BlockPushingListener listener = pushBlocks(
       blocks,
       blockIds,
-      Arrays.asList(new PushBlockStream("app-id",0,  0, 0, 0, 0),
-        new PushBlockStream("app-id", 0, 0, 1, 0, 1),
-        new PushBlockStream("app-id", 0, 0, 2, 0, 2)));
+      Arrays.asList(new PushBlockStream("app-id",0,  0, 0, 0, 0, 0),
+        new PushBlockStream("app-id", 0, 0, 0, 1, 0, 1),
+        new PushBlockStream("app-id", 0, 0, 0, 2, 0, 2)));
 
-    verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_0_0"), any());
-    verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_1_0"), any());
-    verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_2_0"), any());
+    verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_0_0_0"), any());
+    verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_0_1_0"), any());
+    verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_0_2_0"), any());
   }
 
   @Test
   public void testServerFailures() {
     LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
-    blocks.put("shufflePush_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
-    blocks.put("shufflePush_0_1_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
-    blocks.put("shufflePush_0_2_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
+    blocks.put("shufflePush_0_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
+    blocks.put("shufflePush_0_0_1_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
+    blocks.put("shufflePush_0_0_2_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
     String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
 
     BlockPushingListener listener = pushBlocks(
       blocks,
       blockIds,
-      Arrays.asList(new PushBlockStream("app-id", 0, 0, 0, 0, 0),
-        new PushBlockStream("app-id", 0, 0, 1, 0, 1),
-        new PushBlockStream("app-id", 0, 0, 2, 0, 2)));
+      Arrays.asList(new PushBlockStream("app-id", 0, 0, 0, 0, 0, 0),
+        new PushBlockStream("app-id", 0, 0, 0, 1, 0, 1),
+        new PushBlockStream("app-id", 0, 0, 0, 2, 0, 2)));
 
-    verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_0_0"), any());
-    verify(listener, times(1)).onBlockPushFailure(eq("shufflePush_0_1_0"), any());
-    verify(listener, times(1)).onBlockPushFailure(eq("shufflePush_0_2_0"), any());
+    verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_0_0_0"), any());
+    verify(listener, times(1)).onBlockPushFailure(eq("shufflePush_0_0_1_0"), any());
+    verify(listener, times(1)).onBlockPushFailure(eq("shufflePush_0_0_2_0"), any());
   }
 
   @Test
   public void testHandlingRetriableFailures() {
     LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
-    blocks.put("shufflePush_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
-    blocks.put("shufflePush_0_1_0", null);
-    blocks.put("shufflePush_0_2_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
+    blocks.put("shufflePush_0_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
+    blocks.put("shufflePush_0_0_1_0", null);
+    blocks.put("shufflePush_0_0_2_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
     String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
 
     BlockPushingListener listener = pushBlocks(
       blocks,
       blockIds,
-      Arrays.asList(new PushBlockStream("app-id", 0, 0, 0, 0, 0),
-        new PushBlockStream("app-id", 0, 0, 1, 0, 1),
-        new PushBlockStream("app-id", 0, 0, 2, 0, 2)));
-
-    verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_0_0"), any());
-    verify(listener, times(0)).onBlockPushSuccess(not(eq("shufflePush_0_0_0")), any());
-    verify(listener, times(0)).onBlockPushFailure(eq("shufflePush_0_0_0"), any());
-    verify(listener, times(1)).onBlockPushFailure(eq("shufflePush_0_1_0"), any());
-    verify(listener, times(2)).onBlockPushFailure(eq("shufflePush_0_2_0"), any());
+      Arrays.asList(new PushBlockStream("app-id", 0, 0, 0, 0, 0, 0),
+        new PushBlockStream("app-id", 0, 0, 0, 1, 0, 1),
+        new PushBlockStream("app-id", 0, 0, 0, 2, 0, 2)));
+
+    verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_0_0_0"), any());
+    verify(listener, times(0)).onBlockPushSuccess(not(eq("shufflePush_0_0_0_0")), any());
+    verify(listener, times(0)).onBlockPushFailure(eq("shufflePush_0_0_0_0"), any());
+    verify(listener, times(1)).onBlockPushFailure(eq("shufflePush_0_0_1_0"), any());
+    verify(listener, times(2)).onBlockPushFailure(eq("shufflePush_0_0_2_0"), any());
   }
 
   /**
@@ -147,7 +148,7 @@ public class OneForOneBlockPusherSuite {
           + ErrorHandler.BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX));
       } else {
         callback.onFailure(new RuntimeException("Quick fail " + entry.getKey()
-          + ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX));
+          + ErrorHandler.BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX));
       }
       assertEquals(msgIterator.next(), message);
       return null;
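
The pusher tests above pair each PushBlockStream message with a block ID of the
form shufflePush_<shuffleId>_<shuffleMergeId>_<mapIndex>_<reduceId>; the message
constructor now takes seven arguments, in these tests (appId, appAttemptId,
shuffleId, shuffleMergeId, mapIndex, reduceId, index). A sketch of deriving the
block ID from those fields; buildPushBlockId is an illustrative helper, not an
API added by this patch:

    // Illustrative helper: builds the push block ID that corresponds to a
    // PushBlockStream message, assuming the layout
    // shufflePush_<shuffleId>_<shuffleMergeId>_<mapIndex>_<reduceId>.
    static String buildPushBlockId(int shuffleId, int shuffleMergeId, int mapIndex, int reduceId) {
      return "shufflePush_" + shuffleId + "_" + shuffleMergeId + "_" + mapIndex + "_" + reduceId;
    }

For example, buildPushBlockId(0, 0, 1, 0) returns "shufflePush_0_0_1_0", the
second block pushed in testPushThree.
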
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java
index 3a7c14c..c69e57d 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java
@@ -106,7 +106,7 @@ public class RemoteBlockPushResolverSuite {
   @Test(expected = RuntimeException.class)
   public void testNoIndexFile() {
     try {
-      pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
+      pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
     } catch (Throwable t) {
       assertTrue(t.getMessage().startsWith("Merged shuffle index file"));
       Throwables.propagate(t);
@@ -116,58 +116,58 @@ public class RemoteBlockPushResolverSuite {
   @Test
   public void testBasicBlockMerge() throws IOException {
     PushBlock[] pushBlocks = new PushBlock[] {
-      new PushBlock(0, 0, 0, ByteBuffer.wrap(new byte[4])),
-      new PushBlock(0, 1, 0, ByteBuffer.wrap(new byte[5]))
+      new PushBlock(0, 0, 0, 0, ByteBuffer.wrap(new byte[4])),
+      new PushBlock(0, 0, 1, 0, ByteBuffer.wrap(new byte[5]))
     };
     pushBlockHelper(TEST_APP, NO_ATTEMPT_ID, pushBlocks);
     MergeStatuses statuses = pushResolver.finalizeShuffleMerge(
-      new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0));
+      new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0));
     validateMergeStatuses(statuses, new int[] {0}, new long[] {9});
-    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
-    validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{4, 5}, new int[][]{{0}, {1}});
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
+    validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{4, 5}, new int[][]{{0}, {1}});
   }
 
   @Test
   public void testDividingMergedBlocksIntoChunks() throws IOException {
     PushBlock[] pushBlocks = new PushBlock[] {
-      new PushBlock(0, 0, 0, ByteBuffer.wrap(new byte[2])),
-      new PushBlock(0, 1, 0, ByteBuffer.wrap(new byte[3])),
-      new PushBlock(0, 2, 0, ByteBuffer.wrap(new byte[5])),
-      new PushBlock(0, 3, 0, ByteBuffer.wrap(new byte[3]))
+      new PushBlock(0, 0, 0, 0, ByteBuffer.wrap(new byte[2])),
+      new PushBlock(0, 0, 1, 0, ByteBuffer.wrap(new byte[3])),
+      new PushBlock(0, 0, 2, 0, ByteBuffer.wrap(new byte[5])),
+      new PushBlock(0, 0, 3, 0, ByteBuffer.wrap(new byte[3]))
     };
     pushBlockHelper(TEST_APP, NO_ATTEMPT_ID, pushBlocks);
     MergeStatuses statuses = pushResolver.finalizeShuffleMerge(
-      new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0));
+      new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0));
     validateMergeStatuses(statuses, new int[] {0}, new long[] {13});
-    MergedBlockMeta meta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
-    validateChunks(TEST_APP, 0, 0, meta, new int[]{5, 5, 3}, new int[][]{{0, 1}, {2}, {3}});
+    MergedBlockMeta meta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
+    validateChunks(TEST_APP, 0, 0, 0, meta, new int[]{5, 5, 3}, new int[][]{{0, 1}, {2}, {3}});
   }
 
   @Test
   public void testFinalizeWithMultipleReducePartitions() throws IOException {
     PushBlock[] pushBlocks = new PushBlock[] {
-      new PushBlock(0, 0, 0, ByteBuffer.wrap(new byte[2])),
-      new PushBlock(0, 1, 0, ByteBuffer.wrap(new byte[3])),
-      new PushBlock(0, 0, 1, ByteBuffer.wrap(new byte[5])),
-      new PushBlock(0, 1, 1, ByteBuffer.wrap(new byte[3]))
+      new PushBlock(0, 0, 0, 0, ByteBuffer.wrap(new byte[2])),
+      new PushBlock(0, 0, 1, 0, ByteBuffer.wrap(new byte[3])),
+      new PushBlock(0, 0, 0, 1, ByteBuffer.wrap(new byte[5])),
+      new PushBlock(0, 0, 1, 1, ByteBuffer.wrap(new byte[3]))
     };
     pushBlockHelper(TEST_APP, NO_ATTEMPT_ID, pushBlocks);
     MergeStatuses statuses = pushResolver.finalizeShuffleMerge(
-      new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0));
+      new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0));
     validateMergeStatuses(statuses, new int[] {0, 1}, new long[] {5, 8});
-    MergedBlockMeta meta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
-    validateChunks(TEST_APP, 0, 0, meta, new int[]{5}, new int[][]{{0, 1}});
+    MergedBlockMeta meta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
+    validateChunks(TEST_APP, 0, 0, 0, meta, new int[]{5}, new int[][]{{0, 1}});
   }
 
   @Test
   public void testDeferredBufsAreWrittenDuringOnData() throws IOException {
     StreamCallbackWithID stream1 =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0));
     stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
     StreamCallbackWithID stream2 =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0));
     // This should be deferred
     stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[3]));
     // stream 1 now completes
@@ -176,20 +176,20 @@ public class RemoteBlockPushResolverSuite {
     // stream 2 has more data and then completes
     stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[3]));
     stream2.onComplete(stream2.getID());
-    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0));
-    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
-    validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{4, 6}, new int[][]{{0}, {1}});
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0));
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
+    validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{4, 6}, new int[][]{{0}, {1}});
   }
 
   @Test
   public void testDeferredBufsAreWrittenDuringOnComplete() throws IOException {
     StreamCallbackWithID stream1 =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0));
     stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
     StreamCallbackWithID stream2 =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0));
     // This should be deferred
     stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[3]));
     stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[3]));
@@ -198,40 +198,40 @@ public class RemoteBlockPushResolverSuite {
     stream1.onComplete(stream1.getID());
     // stream 2 now completes
     stream2.onComplete(stream2.getID());
-    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0));
-    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
-    validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{4, 6}, new int[][]{{0}, {1}});
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0));
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
+    validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{4, 6}, new int[][]{{0}, {1}});
   }
 
   @Test
   public void testDuplicateBlocksAreIgnoredWhenPrevStreamHasCompleted() throws IOException {
     StreamCallbackWithID stream1 =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0));
     stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
     stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
     stream1.onComplete(stream1.getID());
     StreamCallbackWithID stream2 =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0));
     // This should be ignored
     stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2]));
     stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2]));
     stream2.onComplete(stream2.getID());
-    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0));
-    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
-    validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{4}, new int[][]{{0}});
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0));
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
+    validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{4}, new int[][]{{0}});
   }
 
   @Test
   public void testDuplicateBlocksAreIgnoredWhenPrevStreamIsInProgress() throws IOException {
     StreamCallbackWithID stream1 =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0));
     stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
     StreamCallbackWithID stream2 =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0));
     // This should be ignored
     stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2]));
     stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2]));
@@ -240,20 +240,20 @@ public class RemoteBlockPushResolverSuite {
     stream1.onComplete(stream1.getID());
     // stream 2 now completes
     stream2.onComplete(stream2.getID());
-    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0));
-    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
-    validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{4}, new int[][]{{0}});
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0));
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
+    validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{4}, new int[][]{{0}});
   }
 
   @Test
   public void testFailureAfterData() throws IOException {
     StreamCallbackWithID stream =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0));
     stream.onData(stream.getID(), ByteBuffer.wrap(new byte[4]));
     stream.onFailure(stream.getID(), new RuntimeException("Forced Failure"));
-    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0));
-    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0));
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
     assertEquals("num-chunks", 0, blockMeta.getNumChunks());
   }
 
@@ -261,13 +261,13 @@ public class RemoteBlockPushResolverSuite {
   public void testFailureAfterMultipleDataBlocks() throws IOException {
     StreamCallbackWithID stream =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0));
     stream.onData(stream.getID(), ByteBuffer.wrap(new byte[2]));
     stream.onData(stream.getID(), ByteBuffer.wrap(new byte[3]));
     stream.onData(stream.getID(), ByteBuffer.wrap(new byte[4]));
     stream.onFailure(stream.getID(), new RuntimeException("Forced Failure"));
-    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0));
-    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0));
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
     assertEquals("num-chunks", 0, blockMeta.getNumChunks());
   }
 
@@ -275,15 +275,15 @@ public class RemoteBlockPushResolverSuite {
   public void testFailureAfterComplete() throws IOException {
     StreamCallbackWithID stream =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0));
     stream.onData(stream.getID(), ByteBuffer.wrap(new byte[2]));
     stream.onData(stream.getID(), ByteBuffer.wrap(new byte[3]));
     stream.onData(stream.getID(), ByteBuffer.wrap(new byte[4]));
     stream.onComplete(stream.getID());
     stream.onFailure(stream.getID(), new RuntimeException("Forced Failure"));
-    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0));
-    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
-    validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{9}, new int[][]{{0}});
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0));
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
+    validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{9}, new int[][]{{0}});
   }
 
   @Test(expected = RuntimeException.class)
@@ -293,22 +293,23 @@ public class RemoteBlockPushResolverSuite {
       ByteBuffer.wrap(new byte[5])
     };
     StreamCallbackWithID stream = pushResolver.receiveBlockDataAsStream(
-      new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0));
+      new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0));
     for (ByteBuffer block : blocks) {
       stream.onData(stream.getID(), block);
     }
     stream.onComplete(stream.getID());
-    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0));
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0));
     StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream(
-      new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0));
+      new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0));
     stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[4]));
     try {
       stream1.onComplete(stream1.getID());
     } catch (RuntimeException re) {
-      assertEquals(
-        "Block shufflePush_0_1_0 received after merged shuffle is finalized", re.getMessage());
-      MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
-      validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{9}, new int[][]{{0}});
+      assertEquals("Block shufflePush_0_0_1_0 received after merged shuffle is finalized or stale"
+        + " block push as shuffle blocks of a higher shuffleMergeId for the shuffle is being"
+          + " pushed", re.getMessage());
+      MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
+      validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{9}, new int[][]{{0}});
       throw re;
     }
   }
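
The widened assertion above covers the two cases the resolver now rejects with
one message: a block arriving after its merge was finalized, and a stale block
whose shuffleMergeId is lower than the one currently being merged because an
indeterminate stage retry bumped the ID. A hedged sketch of that decision; the
parameter names are illustrative, not the patch's actual fields:

    // Illustrative accept/reject rule for an incoming push, given the
    // shuffle's current merge ID and whether that merge was finalized.
    static boolean shouldRejectPush(int pushedMergeId, int currentMergeId, boolean finalized) {
      if (pushedMergeId < currentMergeId) {
        return true;  // stale: a higher shuffleMergeId is already being merged
      }
      return pushedMergeId == currentMergeId && finalized;  // too late: already finalized
    }
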
@@ -321,7 +322,7 @@ public class RemoteBlockPushResolverSuite {
 
     StreamCallbackWithID stream1 =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0));
     byte[] data = new byte[10];
     ThreadLocalRandom.current().nextBytes(data);
     stream1.onData(stream1.getID(), ByteBuffer.wrap(data));
@@ -329,21 +330,21 @@ public class RemoteBlockPushResolverSuite {
     stream1.onFailure(stream1.getID(), new RuntimeException("forced error"));
     StreamCallbackWithID stream2 =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0));
     ByteBuffer nextBuf = ByteBuffer.wrap(expectedBytes, 0, 2);
     stream2.onData(stream2.getID(), nextBuf);
     stream2.onComplete(stream2.getID());
     StreamCallbackWithID stream3 =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 2, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 2, 0, 0));
     nextBuf = ByteBuffer.wrap(expectedBytes, 2, 2);
     stream3.onData(stream3.getID(), nextBuf);
     stream3.onComplete(stream3.getID());
-    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0));
-    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
-    validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{4}, new int[][]{{1, 2}});
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0));
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
+    validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{4}, new int[][]{{1, 2}});
     FileSegmentManagedBuffer mb =
-      (FileSegmentManagedBuffer) pushResolver.getMergedBlockData(TEST_APP, 0, 0, 0);
+      (FileSegmentManagedBuffer) pushResolver.getMergedBlockData(TEST_APP, 0, 0, 0, 0);
     assertArrayEquals(expectedBytes, mb.nioByteBuffer().array());
   }
 
@@ -351,11 +352,11 @@ public class RemoteBlockPushResolverSuite {
   public void testCollision() throws IOException {
     StreamCallbackWithID stream1 =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0));
     stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
     StreamCallbackWithID stream2 =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0));
     // This should be deferred
     stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[5]));
     // Since stream2 didn't get any opportunity, it will throw a "couldn't find an opportunity" error
@@ -363,7 +364,7 @@ public class RemoteBlockPushResolverSuite {
       stream2.onComplete(stream2.getID());
     } catch (RuntimeException re) {
       assertEquals(
-        "Couldn't find an opportunity to write block shufflePush_0_1_0 to merged shuffle",
+        "Couldn't find an opportunity to write block shufflePush_0_0_1_0 to merged shuffle",
         re.getMessage());
       throw re;
     }
@@ -373,16 +374,16 @@ public class RemoteBlockPushResolverSuite {
   public void testFailureInAStreamDoesNotInterfereWithStreamWhichIsWriting() throws IOException {
     StreamCallbackWithID stream1 =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0));
     stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
     StreamCallbackWithID stream2 =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0));
     // There is a failure with stream2
     stream2.onFailure(stream2.getID(), new RuntimeException("forced error"));
     StreamCallbackWithID stream3 =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 2, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 2, 0, 0));
     // This should be deferred
     stream3.onData(stream3.getID(), ByteBuffer.wrap(new byte[5]));
     // Since this stream didn't get any opportunity, it will throw a "couldn't find an opportunity" error
@@ -391,7 +392,7 @@ public class RemoteBlockPushResolverSuite {
       stream3.onComplete(stream3.getID());
     } catch (RuntimeException re) {
       assertEquals(
-        "Couldn't find an opportunity to write block shufflePush_0_2_0 to merged shuffle",
+        "Couldn't find an opportunity to write block shufflePush_0_0_2_0 to merged shuffle",
         re.getMessage());
       failedEx = re;
     }
@@ -399,9 +400,9 @@ public class RemoteBlockPushResolverSuite {
     stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
     stream1.onComplete(stream1.getID());
 
-    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0));
-    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
-    validateChunks(TEST_APP, 0, 0, blockMeta, new int[] {4}, new int[][] {{0}});
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0));
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
+    validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[] {4}, new int[][] {{0}});
     if (failedEx != null) {
       throw failedEx;
     }
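
These collision tests pair with the ErrorHandler change earlier in this diff:
an append collision is transient (another stream held the writer), while a
too-late or stale push can never succeed on retry. A sketch of that client-side
split, keyed off the message affixes that OneForOneBlockPusherSuite asserts;
the method name and the default for other failures are illustrative assumptions:

    // Illustrative retry split, assuming failures are classified via the
    // message prefix/suffix constants from org.apache.spark.network.shuffle.ErrorHandler.
    static boolean isRetriablePushFailure(Throwable t) {
      String msg = t.getMessage();
      if (msg == null) {
        return false;
      }
      if (msg.contains(
          ErrorHandler.BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX)) {
        return false;  // the server has moved on; retrying cannot succeed
      }
      // Append collisions are transient, so they are worth retrying.
      return msg.contains(
        ErrorHandler.BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX);
    }
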
@@ -502,11 +503,11 @@ public class RemoteBlockPushResolverSuite {
     Path[] activeDirs = createLocalDirs(1);
     registerExecutor(testApp, prepareLocalDirs(activeDirs, MERGE_DIRECTORY), MERGE_DIRECTORY_META);
     PushBlock[] pushBlocks = new PushBlock[] {
-      new PushBlock(0, 0, 0, ByteBuffer.wrap(new byte[4]))};
+      new PushBlock(0, 0, 0, 0, ByteBuffer.wrap(new byte[4]))};
     pushBlockHelper(testApp, NO_ATTEMPT_ID, pushBlocks);
-    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(testApp, NO_ATTEMPT_ID, 0));
-    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(testApp, 0, 0);
-    validateChunks(testApp, 0, 0, blockMeta, new int[]{4}, new int[][]{{0}});
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(testApp, NO_ATTEMPT_ID, 0, 0));
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(testApp, 0, 0, 0);
+    validateChunks(testApp, 0, 0, 0, blockMeta, new int[]{4}, new int[][]{{0}});
     String[] mergeDirs = pushResolver.getMergedBlockDirs(testApp);
     pushResolver.applicationRemoved(testApp, true);
     // Since the cleanup happens in a different thread, check a few times to see if the merge dirs get
@@ -522,7 +523,7 @@ public class RemoteBlockPushResolverSuite {
     useTestFiles(true, false);
     RemoteBlockPushResolver.PushBlockStreamCallback callback1 =
       (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0));
     callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[4]));
     callback1.onComplete(callback1.getID());
     RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo = callback1.getPartitionInfo();
@@ -530,7 +531,7 @@ public class RemoteBlockPushResolverSuite {
     TestMergeShuffleFile testIndexFile = (TestMergeShuffleFile) partitionInfo.getIndexFile();
     testIndexFile.close();
     StreamCallbackWithID callback2 = pushResolver.receiveBlockDataAsStream(
-      new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0));
+      new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0));
     callback2.onData(callback2.getID(), ByteBuffer.wrap(new byte[5]));
     // This will complete without any IOExceptions because the number of IOExceptions is less than
     // the threshold, but the update to the index file will be unsuccessful.
@@ -539,15 +540,15 @@ public class RemoteBlockPushResolverSuite {
     // Restore the index stream so it can write successfully again.
     testIndexFile.restore();
     StreamCallbackWithID callback3 = pushResolver.receiveBlockDataAsStream(
-      new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 2, 0, 0));
+      new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 2, 0, 0));
     callback3.onData(callback3.getID(), ByteBuffer.wrap(new byte[2]));
     callback3.onComplete(callback3.getID());
     assertEquals("index position", 24, testIndexFile.getPos());
     MergeStatuses statuses = pushResolver.finalizeShuffleMerge(
-      new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0));
+      new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0));
     validateMergeStatuses(statuses, new int[] {0}, new long[] {11});
-    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
-    validateChunks(TEST_APP, 0, 0, blockMeta, new int[] {4, 7}, new int[][] {{0}, {1, 2}});
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
+    validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[] {4, 7}, new int[][] {{0}, {1, 2}});
   }
 
   @Test
@@ -555,7 +556,7 @@ public class RemoteBlockPushResolverSuite {
     useTestFiles(true, false);
     RemoteBlockPushResolver.PushBlockStreamCallback callback1 =
       (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0));
     callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[4]));
     callback1.onComplete(callback1.getID());
     RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo = callback1.getPartitionInfo();
@@ -563,7 +564,7 @@ public class RemoteBlockPushResolverSuite {
     TestMergeShuffleFile testIndexFile = (TestMergeShuffleFile) partitionInfo.getIndexFile();
     testIndexFile.close();
     StreamCallbackWithID callback2 = pushResolver.receiveBlockDataAsStream(
-      new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0));
+      new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0));
     callback2.onData(callback2.getID(), ByteBuffer.wrap(new byte[5]));
     // This will complete without any IOExceptions because the number of IOExceptions is less than
     // the threshold, but the update to the index file will be unsuccessful.
@@ -573,11 +574,11 @@ public class RemoteBlockPushResolverSuite {
     // Restore the index stream so it can write successfully again.
     testIndexFile.restore();
     MergeStatuses statuses = pushResolver.finalizeShuffleMerge(
-      new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0));
+      new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0));
     assertEquals("index position", 24, testIndexFile.getPos());
     validateMergeStatuses(statuses, new int[] {0}, new long[] {9});
-    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
-    validateChunks(TEST_APP, 0, 0, blockMeta, new int[] {4, 5}, new int[][] {{0}, {1}});
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
+    validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[] {4, 5}, new int[][] {{0}, {1}});
   }
 
   @Test
@@ -585,7 +586,7 @@ public class RemoteBlockPushResolverSuite {
     useTestFiles(false, true);
     RemoteBlockPushResolver.PushBlockStreamCallback callback1 =
       (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0));
     callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[4]));
     callback1.onComplete(callback1.getID());
     RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo = callback1.getPartitionInfo();
@@ -594,7 +595,7 @@ public class RemoteBlockPushResolverSuite {
     long metaPosBeforeClose = testMetaFile.getPos();
     testMetaFile.close();
     StreamCallbackWithID callback2 = pushResolver.receiveBlockDataAsStream(
-      new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0));
+      new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0));
     callback2.onData(callback2.getID(), ByteBuffer.wrap(new byte[5]));
     // This will complete without any IOExceptions because the number of IOExceptions is less than
     // the threshold, but the updates to the index and meta files will be unsuccessful.
@@ -604,16 +605,16 @@ public class RemoteBlockPushResolverSuite {
     // Restore the meta stream so it can write successfully again.
     testMetaFile.restore();
     StreamCallbackWithID callback3 = pushResolver.receiveBlockDataAsStream(
-      new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 2, 0, 0));
+      new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 2, 0, 0));
     callback3.onData(callback3.getID(), ByteBuffer.wrap(new byte[2]));
     callback3.onComplete(callback3.getID());
     assertEquals("index position", 24, partitionInfo.getIndexFile().getPos());
     assertTrue("meta position", testMetaFile.getPos() > metaPosBeforeClose);
     MergeStatuses statuses = pushResolver.finalizeShuffleMerge(
-      new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0));
+      new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0));
     validateMergeStatuses(statuses, new int[] {0}, new long[] {11});
-    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
-    validateChunks(TEST_APP, 0, 0, blockMeta, new int[] {4, 7}, new int[][] {{0}, {1, 2}});
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
+    validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[] {4, 7}, new int[][] {{0}, {1, 2}});
   }
 
   @Test
@@ -621,7 +622,7 @@ public class RemoteBlockPushResolverSuite {
     useTestFiles(false, true);
     RemoteBlockPushResolver.PushBlockStreamCallback callback1 =
       (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0));
     callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[4]));
     callback1.onComplete(callback1.getID());
     RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo = callback1.getPartitionInfo();
@@ -630,7 +631,7 @@ public class RemoteBlockPushResolverSuite {
     long metaPosBeforeClose = testMetaFile.getPos();
     testMetaFile.close();
     StreamCallbackWithID callback2 = pushResolver.receiveBlockDataAsStream(
-      new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0));
+      new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0));
     callback2.onData(callback2.getID(), ByteBuffer.wrap(new byte[5]));
     // This will complete without any IOExceptions because the number of IOExceptions is less than
     // the threshold, but the updates to the index and meta files will be unsuccessful.
@@ -641,19 +642,19 @@ public class RemoteBlockPushResolverSuite {
     // Restore the meta stream so it can write successfully again.
     testMetaFile.restore();
     MergeStatuses statuses = pushResolver.finalizeShuffleMerge(
-      new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0));
+      new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0));
     assertEquals("index position", 24, indexFile.getPos());
     assertTrue("meta position", testMetaFile.getPos() > metaPosBeforeClose);
     validateMergeStatuses(statuses, new int[] {0}, new long[] {9});
-    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
-    validateChunks(TEST_APP, 0, 0, blockMeta, new int[] {4, 5}, new int[][] {{0}, {1}});
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
+    validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[] {4, 5}, new int[][] {{0}, {1}});
   }
 
   @Test(expected = RuntimeException.class)
   public void testIOExceptionsExceededThreshold() throws IOException {
     RemoteBlockPushResolver.PushBlockStreamCallback callback =
       (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0));
     RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo = callback.getPartitionInfo();
     callback.onData(callback.getID(), ByteBuffer.wrap(new byte[4]));
     callback.onComplete(callback.getID());
@@ -662,7 +663,7 @@ public class RemoteBlockPushResolverSuite {
     for (int i = 1; i < 5; i++) {
       RemoteBlockPushResolver.PushBlockStreamCallback callback1 =
         (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream(
-          new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, i, 0, 0));
+          new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, i, 0, 0));
       try {
         callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[2]));
       } catch (IOException ioe) {
@@ -675,10 +676,10 @@ public class RemoteBlockPushResolverSuite {
     try {
       RemoteBlockPushResolver.PushBlockStreamCallback callback2 =
         (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream(
-          new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 5, 0, 0));
+          new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 5, 0, 0));
       callback2.onData(callback.getID(), ByteBuffer.wrap(new byte[1]));
     } catch (Throwable t) {
-      assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_5_0",
+      assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_5_0",
         t.getMessage());
       throw t;
     }
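
The threshold tests here count IOExceptions per shuffle partition and fail
pushes fast once the count reaches a configured limit (4 in this suite). A
minimal sketch of such a guard; the class, the method names, and the exact
comparison are assumptions that only mirror the behaviour these tests assert:

    // Illustrative per-partition guard: once too many IOExceptions accumulate,
    // every further merge attempt fails fast with the message the tests expect.
    final class IOExceptionGuard {
      private final int threshold;
      private int numIOExceptions = 0;

      IOExceptionGuard(int threshold) {
        this.threshold = threshold;
      }

      void onIOException() {
        numIOExceptions++;
      }

      void checkBeforeWrite(String blockId) {
        if (numIOExceptions >= threshold) {
          throw new RuntimeException(
            "IOExceptions exceeded the threshold when merging " + blockId);
        }
      }
    }
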
@@ -689,7 +690,7 @@ public class RemoteBlockPushResolverSuite {
     useTestFiles(true, false);
     RemoteBlockPushResolver.PushBlockStreamCallback callback =
       (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0));
     RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo = callback.getPartitionInfo();
     callback.onData(callback.getID(), ByteBuffer.wrap(new byte[4]));
     callback.onComplete(callback.getID());
@@ -698,7 +699,7 @@ public class RemoteBlockPushResolverSuite {
     for (int i = 1; i < 5; i++) {
       RemoteBlockPushResolver.PushBlockStreamCallback callback1 =
         (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream(
-          new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, i, 0, 0));
+          new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, i, 0, 0));
       callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[5]));
       // This will complete without any exceptions but the exception count is increased.
       callback1.onComplete(callback1.getID());
@@ -709,11 +710,11 @@ public class RemoteBlockPushResolverSuite {
     try {
       RemoteBlockPushResolver.PushBlockStreamCallback callback2 =
       (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 5, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 5, 0, 0));
       callback2.onData(callback2.getID(), ByteBuffer.wrap(new byte[4]));
       callback2.onComplete(callback2.getID());
     } catch (Throwable t) {
-      assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_5_0",
+      assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_5_0",
         t.getMessage());
       throw t;
     }
@@ -728,9 +729,9 @@ public class RemoteBlockPushResolverSuite {
     }
     try {
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 10, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 10, 0, 0));
     } catch (Throwable t) {
-      assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_10_0",
+      assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_10_0",
         t.getMessage());
       throw t;
     }
@@ -741,14 +742,14 @@ public class RemoteBlockPushResolverSuite {
     useTestFiles(true, false);
     RemoteBlockPushResolver.PushBlockStreamCallback callback =
       (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0));
     RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo = callback.getPartitionInfo();
     TestMergeShuffleFile testIndexFile = (TestMergeShuffleFile) partitionInfo.getIndexFile();
     testIndexFile.close();
     for (int i = 1; i < 6; i++) {
       RemoteBlockPushResolver.PushBlockStreamCallback callback1 =
         (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream(
-          new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, i, 0, 0));
+          new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, i, 0, 0));
       try {
         callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[5]));
         // This will complete without any exceptions but the exception count is increased.
@@ -763,7 +764,7 @@ public class RemoteBlockPushResolverSuite {
     try {
       callback.onData(callback.getID(), ByteBuffer.wrap(new byte[4]));
     } catch (Throwable t) {
-      assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_0",
+      assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_0_0",
         t.getMessage());
       throw t;
     }
@@ -774,14 +775,14 @@ public class RemoteBlockPushResolverSuite {
     useTestFiles(true, false);
     RemoteBlockPushResolver.PushBlockStreamCallback callback =
       (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0));
     RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo = callback.getPartitionInfo();
     TestMergeShuffleFile testIndexFile = (TestMergeShuffleFile) partitionInfo.getIndexFile();
     testIndexFile.close();
     for (int i = 1; i < 5; i++) {
       RemoteBlockPushResolver.PushBlockStreamCallback callback1 =
         (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream(
-          new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, i, 0, 0));
+          new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, i, 0, 0));
       try {
         callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[5]));
         // This will complete without any exceptions but the exception count is increased.
@@ -793,7 +794,7 @@ public class RemoteBlockPushResolverSuite {
     assertEquals(4, partitionInfo.getNumIOExceptions());
     RemoteBlockPushResolver.PushBlockStreamCallback callback2 =
       (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, 1, 0, 5, 0, 0));
+        new PushBlockStream(TEST_APP, 1, 0, 0, 5, 0, 0));
     callback2.onData(callback2.getID(), ByteBuffer.wrap(new byte[5]));
     // This is deferred
     callback.onData(callback.getID(), ByteBuffer.wrap(new byte[4]));
@@ -820,15 +821,15 @@ public class RemoteBlockPushResolverSuite {
   public void testFailureWhileTruncatingFiles() throws IOException {
     useTestFiles(true, false);
     PushBlock[] pushBlocks = new PushBlock[] {
-      new PushBlock(0, 0, 0, ByteBuffer.wrap(new byte[2])),
-      new PushBlock(0, 1, 0, ByteBuffer.wrap(new byte[3])),
-      new PushBlock(0, 0, 1, ByteBuffer.wrap(new byte[5])),
-      new PushBlock(0, 1, 1, ByteBuffer.wrap(new byte[3]))
+      new PushBlock(0, 0, 0, 0, ByteBuffer.wrap(new byte[2])),
+      new PushBlock(0, 0, 1, 0, ByteBuffer.wrap(new byte[3])),
+      new PushBlock(0, 0, 0, 1, ByteBuffer.wrap(new byte[5])),
+      new PushBlock(0, 0, 1, 1, ByteBuffer.wrap(new byte[3]))
     };
     pushBlockHelper(TEST_APP, NO_ATTEMPT_ID, pushBlocks);
     RemoteBlockPushResolver.PushBlockStreamCallback callback =
       (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 2, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 2, 0, 0));
     callback.onData(callback.getID(), ByteBuffer.wrap(new byte[2]));
     callback.onComplete(callback.getID());
     RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo = callback.getPartitionInfo();
@@ -836,37 +837,37 @@ public class RemoteBlockPushResolverSuite {
     // Close the index file so truncate throws IOException
     testIndexFile.close();
     MergeStatuses statuses = pushResolver.finalizeShuffleMerge(
-      new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0));
+      new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0));
     validateMergeStatuses(statuses, new int[] {1}, new long[] {8});
-    MergedBlockMeta meta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 1);
-    validateChunks(TEST_APP, 0, 1, meta, new int[]{5, 3}, new int[][]{{0},{1}});
+    MergedBlockMeta meta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 1);
+    validateChunks(TEST_APP, 0, 0, 1, meta, new int[]{5, 3}, new int[][]{{0},{1}});
   }
 
   @Test
   public void testOnFailureInvokedMoreThanOncePerBlock() throws IOException {
     StreamCallbackWithID stream1 =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0));
     stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
     stream1.onFailure(stream1.getID(), new RuntimeException("forced error"));
     StreamCallbackWithID stream2 =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0));
     stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[5]));
     // onFailure on stream1 gets invoked again and should cause no interference
     stream1.onFailure(stream1.getID(), new RuntimeException("2nd forced error"));
     StreamCallbackWithID stream3 =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 3, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 3, 0, 0));
     // This should be deferred as stream 2 is still the active stream
     stream3.onData(stream3.getID(), ByteBuffer.wrap(new byte[2]));
     // Stream 2 writes more and completes
     stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[4]));
     stream2.onComplete(stream2.getID());
     stream3.onComplete(stream3.getID());
-    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0));
-    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
-    validateChunks(TEST_APP, 0, 0, blockMeta, new int[] {9, 2}, new int[][] {{1},{3}});
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0));
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
+    validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[] {9, 2}, new int[][] {{1},{3}});
     removeApplication(TEST_APP);
   }
 
@@ -874,24 +875,24 @@ public class RemoteBlockPushResolverSuite {
   public void testFailureAfterDuplicateBlockDoesNotInterfereActiveStream() throws IOException {
     StreamCallbackWithID stream1 =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0));
     StreamCallbackWithID stream1Duplicate =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0));
     stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
     stream1.onComplete(stream1.getID());
     stream1Duplicate.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
 
     StreamCallbackWithID stream2 =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0));
     stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[5]));
     // Should not change the current map id of the reduce partition
     stream1Duplicate.onFailure(stream2.getID(), new RuntimeException("forced error"));
 
     StreamCallbackWithID stream3 =
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 2, 0, 0));
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 2, 0, 0));
     // This should be deferred as stream 2 is still the active stream
     stream3.onData(stream3.getID(), ByteBuffer.wrap(new byte[2]));
     RuntimeException failedEx = null;
@@ -899,16 +900,16 @@ public class RemoteBlockPushResolverSuite {
       stream3.onComplete(stream3.getID());
     } catch (RuntimeException re) {
       assertEquals(
-        "Couldn't find an opportunity to write block shufflePush_0_2_0 to merged shuffle",
+        "Couldn't find an opportunity to write block shufflePush_0_0_2_0 to merged shuffle",
         re.getMessage());
       failedEx = re;
     }
     // Stream 2 writes more and completes
     stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[4]));
     stream2.onComplete(stream2.getID());
-    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0));
-    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
-    validateChunks(TEST_APP, 0, 0, blockMeta, new int[] {11}, new int[][] {{0, 1}});
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0));
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
+    validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[] {11}, new int[][] {{0, 1}});
     removeApplication(TEST_APP);
     if (failedEx != null) {
       throw failedEx;
@@ -938,46 +939,42 @@ public class RemoteBlockPushResolverSuite {
       ByteBuffer.wrap(new byte[5])
     };
     StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream(
-      new PushBlockStream(testApp, 1, 0, 0, 0, 0));
+      new PushBlockStream(testApp, 1, 0, 0, 0, 0, 0));
     for (ByteBuffer block : blocks) {
       stream1.onData(stream1.getID(), block);
     }
     stream1.onComplete(stream1.getID());
     RemoteBlockPushResolver.AppShuffleInfo appShuffleInfo =
       pushResolver.validateAndGetAppShuffleInfo(testApp);
-    Map<Integer, Map<Integer, RemoteBlockPushResolver.AppShufflePartitionInfo>> partitions =
-      appShuffleInfo.getPartitions();
-    for (Map<Integer, RemoteBlockPushResolver.AppShufflePartitionInfo> partitionMap :
-        partitions.values()) {
-      for (RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo : partitionMap.values()) {
-        assertTrue(partitionInfo.getDataChannel().isOpen());
-        assertTrue(partitionInfo.getMetaFile().getChannel().isOpen());
-        assertTrue(partitionInfo.getIndexFile().getChannel().isOpen());
-      }
+    RemoteBlockPushResolver.AppShuffleMergePartitionsInfo partitions =
+      appShuffleInfo.getShuffles().get(0);
+    for (RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo :
+        partitions.getShuffleMergePartitions().values()) {
+      assertTrue(partitionInfo.getDataChannel().isOpen());
+      assertTrue(partitionInfo.getMetaFile().getChannel().isOpen());
+      assertTrue(partitionInfo.getIndexFile().getChannel().isOpen());
     }
     Path[] attempt2LocalDirs = createLocalDirs(2);
     registerExecutor(testApp,
       prepareLocalDirs(attempt2LocalDirs, MERGE_DIRECTORY + "_" + ATTEMPT_ID_2),
       MERGE_DIRECTORY_META_2);
     StreamCallbackWithID stream2 = pushResolver.receiveBlockDataAsStream(
-      new PushBlockStream(testApp, 2, 0, 1, 0, 0));
+      new PushBlockStream(testApp, 2, 0, 0, 1, 0, 0));
     for (ByteBuffer block : blocks) {
       stream2.onData(stream2.getID(), block);
     }
     stream2.onComplete(stream2.getID());
     closed.acquire();
     // Check if all the file channels created for the first attempt are safely closed.
-    for (Map<Integer, RemoteBlockPushResolver.AppShufflePartitionInfo> partitionMap :
-        partitions.values()) {
-      for (RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo : partitionMap.values()) {
-        assertFalse(partitionInfo.getDataChannel().isOpen());
-        assertFalse(partitionInfo.getMetaFile().getChannel().isOpen());
-        assertFalse(partitionInfo.getIndexFile().getChannel().isOpen());
-      }
+    for (RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo :
+      partitions.getShuffleMergePartitions().values()) {
+      assertFalse(partitionInfo.getDataChannel().isOpen());
+      assertFalse(partitionInfo.getMetaFile().getChannel().isOpen());
+      assertFalse(partitionInfo.getIndexFile().getChannel().isOpen());
     }
     try {
       pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(testApp, 1, 0, 1, 0, 0));
+        new PushBlockStream(testApp, 1, 0, 0, 1, 0, 0));
     } catch (IllegalArgumentException re) {
       assertEquals(
         "The attempt id 1 in this PushBlockStream message does not match " +
@@ -1000,7 +997,7 @@ public class RemoteBlockPushResolverSuite {
       ByteBuffer.wrap(new byte[5])
     };
     StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream(
-      new PushBlockStream(testApp, 1, 0, 0, 0, 0));
+      new PushBlockStream(testApp, 1, 0, 0, 0, 0, 0));
     for (ByteBuffer block : blocks) {
       stream1.onData(stream1.getID(), block);
     }
@@ -1010,7 +1007,7 @@ public class RemoteBlockPushResolverSuite {
       prepareLocalDirs(attempt2LocalDirs, MERGE_DIRECTORY + "_" + ATTEMPT_ID_2),
       MERGE_DIRECTORY_META_2);
     try {
-      pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(testApp, ATTEMPT_ID_1, 0));
+      pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(testApp, ATTEMPT_ID_1, 0, 0));
     } catch (IllegalArgumentException e) {
       assertEquals(e.getMessage(),
         String.format("The attempt id %s in this FinalizeShuffleMerge message does not " +
@@ -1022,13 +1019,13 @@ public class RemoteBlockPushResolverSuite {
 
   @Test(expected = ClosedChannelException.class)
   public void testOngoingMergeOfBlockFromPreviousAttemptIsAborted()
-    throws IOException, InterruptedException {
+      throws IOException, InterruptedException {
     Semaphore closed = new Semaphore(0);
     pushResolver = new RemoteBlockPushResolver(conf) {
       @Override
       void closeAndDeletePartitionFilesIfNeeded(
-        AppShuffleInfo appShuffleInfo,
-        boolean cleanupLocalDirs) {
+          AppShuffleInfo appShuffleInfo,
+          boolean cleanupLocalDirs) {
         super.closeAndDeletePartitionFilesIfNeeded(appShuffleInfo, cleanupLocalDirs);
         closed.release();
       }
@@ -1045,7 +1042,7 @@ public class RemoteBlockPushResolverSuite {
       ByteBuffer.wrap(new byte[7])
     };
     StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream(
-      new PushBlockStream(testApp, 1, 0, 0, 0, 0));
+      new PushBlockStream(testApp, 1, 0, 0, 0, 0, 0));
     // The onData callback should be called 4 times here before the onComplete callback. But a
     // register executor message arrives in shuffle service after the 2nd onData callback. The 3rd
     // and 4th onData callbacks should throw ClosedChannelException as their channels are closed.
@@ -1060,17 +1057,202 @@ public class RemoteBlockPushResolverSuite {
     stream1.onData(stream1.getID(), blocks[3]);
   }
 
+  @Test
+  public void testBlockPushWithOlderShuffleMergeId() throws IOException {
+    StreamCallbackWithID stream1 =
+      pushResolver.receiveBlockDataAsStream(
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0, 0));
+    stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
+    StreamCallbackWithID stream2 =
+      pushResolver.receiveBlockDataAsStream(
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 2, 0, 0, 0));
+    stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2]));
+    stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2]));
+    stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
+    try {
+      // stream 1 push should be rejected as it is from an older shuffleMergeId
+      stream1.onComplete(stream1.getID());
+    } catch(RuntimeException re) {
+      assertEquals("Block shufflePush_0_1_0_0 is received after merged shuffle is finalized or"
+        + " stale block push as shuffle blocks of a higher shuffleMergeId for the shuffle is being"
+          + " pushed", re.getMessage());
+    }
+    // stream 2 now completes
+    stream2.onComplete(stream2.getID());
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 2));
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 2, 0);
+    validateChunks(TEST_APP, 0, 2, 0, blockMeta, new int[]{4}, new int[][]{{0}});
+  }
+
+  @Test
+  public void testFinalizeWithOlderShuffleMergeId() throws IOException {
+    StreamCallbackWithID stream1 =
+      pushResolver.receiveBlockDataAsStream(
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0, 0));
+    stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
+    StreamCallbackWithID stream2 =
+      pushResolver.receiveBlockDataAsStream(
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 2, 0, 0, 0));
+    stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2]));
+    stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2]));
+    stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
+    try {
+      // stream 1 push should be rejected as it is from an older shuffleMergeId
+      stream1.onComplete(stream1.getID());
+    } catch(RuntimeException re) {
+      assertEquals("Block shufflePush_0_1_0_0 is received after merged shuffle is finalized or"
+        + " stale block push as shuffle blocks of a higher shuffleMergeId for the shuffle is being"
+          + " pushed", re.getMessage());
+    }
+    // stream 2 now completes
+    stream2.onComplete(stream2.getID());
+    try {
+      pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 1));
+    } catch(RuntimeException re) {
+      assertEquals("Shuffle merge finalize request for shuffle 0 with shuffleMergeId 1 is stale"
+        + " shuffle finalize request as shuffle blocks of a higher shuffleMergeId for the shuffle"
+          + " is already being pushed", re.getMessage());
+    }
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 2));
+
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 2, 0);
+    validateChunks(TEST_APP, 0, 2, 0, blockMeta, new int[]{4}, new int[][]{{0}});
+  }
+
+  @Test
+  public void testFinalizeOfDeterminateShuffle() throws IOException {
+    PushBlock[] pushBlocks = new PushBlock[] {
+      new PushBlock(0, 0, 0, 0, ByteBuffer.wrap(new byte[4])),
+      new PushBlock(0, 0, 1, 0, ByteBuffer.wrap(new byte[5]))
+    };
+    pushBlockHelper(TEST_APP, NO_ATTEMPT_ID, pushBlocks);
+    MergeStatuses statuses = pushResolver.finalizeShuffleMerge(
+      new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0));
+
+    RemoteBlockPushResolver.AppShuffleInfo appShuffleInfo =
+      pushResolver.validateAndGetAppShuffleInfo(TEST_APP);
+    assertTrue("Metadata of determinate shuffle should be removed after finalize shuffle"
+      + " merge", appShuffleInfo.getShuffles().get(0) == null);
+    validateMergeStatuses(statuses, new int[] {0}, new long[] {9});
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
+    validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{4, 5}, new int[][]{{0}, {1}});
+  }
+
+  @Test
+  public void testBlockFetchWithOlderShuffleMergeId() throws IOException {
+    StreamCallbackWithID stream1 =
+      pushResolver.receiveBlockDataAsStream(
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0, 0));
+    stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
+    StreamCallbackWithID stream2 =
+      pushResolver.receiveBlockDataAsStream(
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 2, 0, 0, 0));
+    stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2]));
+    stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2]));
+    stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
+    try {
+      // stream 1 push should be rejected as it is from an older shuffleMergeId
+      stream1.onComplete(stream1.getID());
+    } catch(RuntimeException re) {
+      assertEquals("Block shufflePush_0_1_0_0 is received after merged shuffle is finalized or"
+        + " stale block push as shuffle blocks of a higher shuffleMergeId for the shuffle is being"
+          + " pushed", re.getMessage());
+    }
+    // stream 2 now completes
+    stream2.onComplete(stream2.getID());
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 2));
+    try {
+      pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
+    } catch(RuntimeException re) {
+      assertEquals("MergedBlockMeta fetch for shuffle 0 with shuffleMergeId 0 reduceId 0"
+        + " is stale shuffle block fetch request as shuffle blocks of a higher shuffleMergeId for"
+        + " the shuffle is available", re.getMessage());
+    }
+
+    try {
+      pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 1));
+    } catch(RuntimeException re) {
+      assertEquals("Shuffle merge finalize request for shuffle 0 with shuffleMergeId 1 is stale"
+        + " shuffle finalize request as shuffle blocks of a higher shuffleMergeId for the shuffle"
+          + " is already being pushed", re.getMessage());
+    }
+    try {
+      pushResolver.getMergedBlockData(TEST_APP, 0, 1, 0, 0);
+    } catch(RuntimeException re) {
+      assertEquals("MergedBlockData fetch for shuffle 0 with shuffleMergeId 1 reduceId 0"
+        + " is stale shuffle block fetch request as shuffle blocks of a higher shuffleMergeId for"
+        + " the shuffle is available", re.getMessage());
+    }
+
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 2, 0);
+    validateChunks(TEST_APP, 0, 2, 0, blockMeta, new int[]{4}, new int[][]{{0}});
+  }
+
+  @Test
+  public void testCleanupOlderShuffleMergeId() throws IOException, InterruptedException {
+    Semaphore closed = new Semaphore(0);
+    pushResolver = new RemoteBlockPushResolver(conf) {
+      @Override
+      void closeAndDeletePartitionFiles(Map<Integer, AppShufflePartitionInfo> partitions) {
+        super.closeAndDeletePartitionFiles(partitions);
+        closed.release();
+      }
+    };
+    String testApp = "testCleanupOlderShuffleMergeId";
+    registerExecutor(testApp, prepareLocalDirs(localDirs, MERGE_DIRECTORY), MERGE_DIRECTORY_META);
+    StreamCallbackWithID stream1 =
+      pushResolver.receiveBlockDataAsStream(
+        new PushBlockStream(testApp, NO_ATTEMPT_ID, 0, 1, 0, 0, 0));
+    stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
+    StreamCallbackWithID stream2 =
+      pushResolver.receiveBlockDataAsStream(
+        new PushBlockStream(testApp, NO_ATTEMPT_ID, 0, 2, 0, 0, 0));
+    RemoteBlockPushResolver.AppShuffleInfo appShuffleInfo =
+      pushResolver.validateAndGetAppShuffleInfo(testApp);
+    closed.acquire();
+    assertFalse("Data files on the disk should be cleaned up",
+      appShuffleInfo.getMergedShuffleDataFile(0, 1, 0).exists());
+    assertFalse("Meta files on the disk should be cleaned up",
+      appShuffleInfo.getMergedShuffleMetaFile(0, 1, 0).exists());
+    assertFalse("Index files on the disk should be cleaned up",
+      appShuffleInfo.getMergedShuffleIndexFile(0, 1, 0).exists());
+    stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2]));
+    stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2]));
+    // stream 2 now completes
+    stream2.onComplete(stream2.getID());
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(testApp, NO_ATTEMPT_ID, 0, 2));
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(testApp, 0, 2, 0);
+    validateChunks(testApp, 0, 2, 0, blockMeta, new int[]{4}, new int[][]{{0}});
+
+    // Check whether the metadata is cleaned up or not
+    StreamCallbackWithID stream3 =
+      pushResolver.receiveBlockDataAsStream(
+        new PushBlockStream(testApp, NO_ATTEMPT_ID, 0, 3, 0, 0, 0));
+    closed.acquire();
+    stream3.onData(stream3.getID(), ByteBuffer.wrap(new byte[2]));
+    stream3.onComplete(stream3.getID());
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(testApp, NO_ATTEMPT_ID, 0, 3));
+    MergedBlockMeta mergedBlockMeta = pushResolver.getMergedBlockMeta(testApp, 0, 3, 0);
+    validateChunks(testApp, 0, 3, 0, mergedBlockMeta, new int[]{2}, new int[][]{{0}});
+  }
+
   private void useTestFiles(boolean useTestIndexFile, boolean useTestMetaFile) throws IOException {
     pushResolver = new RemoteBlockPushResolver(conf) {
       @Override
-      AppShufflePartitionInfo newAppShufflePartitionInfo(String appId, int shuffleId,
-          int reduceId, File dataFile, File indexFile, File metaFile) throws IOException {
+      AppShufflePartitionInfo newAppShufflePartitionInfo(
+          String appId,
+          int shuffleId,
+          int shuffleMergeId,
+          int reduceId,
+          File dataFile,
+          File indexFile,
+          File metaFile) throws IOException {
         MergeShuffleFile mergedIndexFile = useTestIndexFile ? new TestMergeShuffleFile(indexFile)
           : new MergeShuffleFile(indexFile);
         MergeShuffleFile mergedMetaFile = useTestMetaFile ? new TestMergeShuffleFile(metaFile) :
           new MergeShuffleFile(metaFile);
-        return new AppShufflePartitionInfo(appId, shuffleId, reduceId, dataFile, mergedIndexFile,
-          mergedMetaFile);
+        return new AppShufflePartitionInfo(appId, shuffleId, shuffleMergeId, reduceId, dataFile,
+          mergedIndexFile, mergedMetaFile);
       }
     };
     registerExecutor(TEST_APP, prepareLocalDirs(localDirs, MERGE_DIRECTORY), MERGE_DIRECTORY_META);
@@ -1116,6 +1298,7 @@ public class RemoteBlockPushResolverSuite {
   private void validateChunks(
       String appId,
       int shuffleId,
+      int shuffleMergeId,
       int reduceId,
       MergedBlockMeta meta,
       int[] expectedSizes,
@@ -1129,7 +1312,8 @@ public class RemoteBlockPushResolverSuite {
     }
     for (int i = 0; i < meta.getNumChunks(); i++) {
       FileSegmentManagedBuffer mb =
-        (FileSegmentManagedBuffer) pushResolver.getMergedBlockData(appId, shuffleId, reduceId, i);
+        (FileSegmentManagedBuffer) pushResolver.getMergedBlockData(appId, shuffleId,
+          shuffleMergeId, reduceId, i);
       assertEquals(expectedSizes[i], mb.getLength());
     }
   }
@@ -1140,8 +1324,8 @@ public class RemoteBlockPushResolverSuite {
       PushBlock[] blocks) throws IOException {
     for (PushBlock block : blocks) {
       StreamCallbackWithID stream = pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(
-          appId, attemptId, block.shuffleId, block.mapIndex, block.reduceId, 0));
+        new PushBlockStream(appId, attemptId, block.shuffleId, block.shuffleMergeId,
+          block.mapIndex, block.reduceId, 0));
       stream.onData(stream.getID(), block.buffer);
       stream.onComplete(stream.getID());
     }
@@ -1149,11 +1333,13 @@ public class RemoteBlockPushResolverSuite {
 
   private static class PushBlock {
     private final int shuffleId;
+    private final int shuffleMergeId;
     private final int mapIndex;
     private final int reduceId;
     private final ByteBuffer buffer;
-    PushBlock(int shuffleId, int mapIndex, int reduceId, ByteBuffer buffer) {
+    PushBlock(int shuffleId, int shuffleMergeId, int mapIndex, int reduceId, ByteBuffer buffer) {
       this.shuffleId = shuffleId;
+      this.shuffleMergeId = shuffleMergeId;
       this.mapIndex = mapIndex;
       this.reduceId = reduceId;
       this.buffer = buffer;
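
For reference, the staleness semantics these new tests pin down can be
summarized in a small self-contained Scala sketch. ShuffleMergeGate below is a
hypothetical stand-in, not code from this patch: the real checks live in
RemoteBlockPushResolver and additionally coordinate concurrent streams,
finalization metadata and file cleanup.

    // Hypothetical model: for each shuffle the server only serves the highest
    // shuffleMergeId seen so far; pushes, finalize requests and fetches that
    // carry a lower id are rejected as stale.
    final class ShuffleMergeGate {
      private var latestMergeId = 0
      def tryPush(mergeId: Int): Boolean =
        if (mergeId < latestMergeId) false      // stale block push
        else { latestMergeId = mergeId; true }  // same or newer id is accepted
      def tryFinalize(mergeId: Int): Boolean = mergeId >= latestMergeId
      def tryFetch(mergeId: Int): Boolean = mergeId >= latestMergeId
    }

    object ShuffleMergeGateDemo extends App {
      val gate = new ShuffleMergeGate
      assert(gate.tryPush(1))       // merge id 1 becomes current
      assert(gate.tryPush(2))       // a higher id supersedes it
      assert(!gate.tryPush(1))      // late pushes for id 1 are now rejected
      assert(!gate.tryFinalize(1))  // as are finalize requests for id 1
      assert(gate.tryFetch(2))      // only the latest id is served
    }

The invariant is the same on all three paths: only the highest shuffleMergeId
observed for a shuffle is accepted, and everything lower is treated as stale.
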
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunksSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunksSuite.java
index 91f319d..c79b01e 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunksSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunksSuite.java
@@ -29,10 +29,10 @@ public class FetchShuffleBlockChunksSuite {
   @Test
   public void testFetchShuffleBlockChunksEncodeDecode() {
     FetchShuffleBlockChunks shuffleBlockChunks =
-      new FetchShuffleBlockChunks("app0", "exec1", 0, new int[] {0}, new int[][] {{0, 1}});
+      new FetchShuffleBlockChunks("app0", "exec1", 0, 0, new int[] {0}, new int[][] {{0, 1}});
     Assert.assertEquals(2, shuffleBlockChunks.getNumBlocks());
     int len = shuffleBlockChunks.encodedLength();
-    Assert.assertEquals(45, len);
+    Assert.assertEquals(49, len);
     ByteBuf buf = Unpooled.buffer(len);
     shuffleBlockChunks.encode(buf);
 
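
Worth noting: the expected encoded length above grows from 45 to 49 bytes
because the message now carries shuffleMergeId as one additional 4-byte int
(49 = 45 + 4); the other protocol messages that gain this field grow by the
same amount.
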
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index 4063d11..81e4c8f 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -20,6 +20,7 @@ package org.apache.spark
 import scala.reflect.ClassTag
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteProcessor}
@@ -78,7 +79,7 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
     val aggregator: Option[Aggregator[K, V, C]] = None,
     val mapSideCombine: Boolean = false,
     val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor)
-  extends Dependency[Product2[K, V]] {
+  extends Dependency[Product2[K, V]] with Logging {
 
   if (mapSideCombine) {
     require(aggregator.isDefined, "Map-side combine without Aggregator specified!")
@@ -101,10 +102,7 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
 
   // By default, shuffle merge is enabled for ShuffleDependency if push based shuffle
   // is enabled
-  private[this] var _shuffleMergeEnabled =
-    Utils.isPushBasedShuffleEnabled(rdd.sparkContext.getConf) &&
-    // TODO: SPARK-35547: Push based shuffle is currently unsupported for Barrier stages
-    !rdd.isBarrier()
+  private[this] var _shuffleMergeEnabled = canShuffleMergeBeEnabled()
 
   private[spark] def setShuffleMergeEnabled(shuffleMergeEnabled: Boolean): Unit = {
     _shuffleMergeEnabled = shuffleMergeEnabled
@@ -124,6 +122,14 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
    */
   private[this] var _shuffleMergedFinalized: Boolean = false
 
+  /**
+   * shuffleMergeId is used to uniquely identify the merging process of a
+   * shuffle by an indeterminate stage attempt.
+   */
+  private[this] var _shuffleMergeId: Int = 0
+
+  def shuffleMergeId: Int = _shuffleMergeId
+
   def setMergerLocs(mergerLocs: Seq[BlockManagerId]): Unit = {
     if (mergerLocs != null) {
       this.mergerLocs = mergerLocs
@@ -150,6 +156,22 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
     }
   }
 
+  def newShuffleMergeState(): Unit = {
+    _shuffleMergedFinalized = false
+    mergerLocs = Nil
+    _shuffleMergeId += 1
+  }
+
+  private def canShuffleMergeBeEnabled(): Boolean = {
+    val isPushShuffleEnabled = Utils.isPushBasedShuffleEnabled(rdd.sparkContext.getConf)
+    if (isPushShuffleEnabled && rdd.isBarrier()) {
+      logWarning("Push-based shuffle is currently not supported for barrier stages")
+    }
+    isPushShuffleEnabled &&
+      // TODO: SPARK-35547: Push based shuffle is currently unsupported for Barrier stages
+      !rdd.isBarrier()
+  }
+
   _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
   _rdd.sparkContext.shuffleDriverComponents.registerShuffle(shuffleId)
 }
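
For context, the lifecycle of the new field can be modelled in isolation.
MergeState below is a hypothetical stand-in that mirrors only the state this
patch adds to ShuffleDependency; the real newShuffleMergeState() is invoked
from the DAGScheduler when an indeterminate stage is resubmitted (see the
DAGScheduler hunk further down).

    // Deterministic shuffles keep shuffleMergeId == 0 for their lifetime;
    // each retry of an indeterminate stage discards the old merge state and
    // advances to a fresh id, giving a temporal order across executions.
    final class MergeState {
      private var finalized = false
      private var mergerLocs: Seq[String] = Nil
      private var mergeId = 0
      def shuffleMergeId: Int = mergeId
      def newShuffleMergeState(): Unit = {
        finalized = false
        mergerLocs = Nil
        mergeId += 1
      }
    }

    object MergeStateDemo extends App {
      val state = new MergeState
      assert(state.shuffleMergeId == 0)  // deterministic case: stays 0
      state.newShuffleMergeState()       // indeterminate stage retry #1
      state.newShuffleMergeState()       // indeterminate stage retry #2
      assert(state.shuffleMergeId == 2)  // shuffle services now write id-2 files
    }
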
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index e605eea..1b25ec5 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -38,7 +38,7 @@ import org.apache.spark.io.CompressionCodec
 import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
 import org.apache.spark.scheduler.{MapStatus, MergeStatus, ShuffleOutputStatus}
 import org.apache.spark.shuffle.MetadataFetchFailedException
-import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
+import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId, ShuffleMergedBlockId}
 import org.apache.spark.util._
 
 /**
@@ -1450,11 +1450,10 @@ private[spark] object MapOutputTracker extends Logging {
           val remainingMapStatuses = if (mergeStatus != null && mergeStatus.totalSize > 0) {
             // If MergeStatus is available for the given partition, add location of the
             // pre-merged shuffle partition for this partition ID. Here we create a
-            // ShuffleBlockId with mapId being SHUFFLE_PUSH_MAP_ID to indicate this is
-            // a merged shuffle block.
+            // ShuffleMergedBlockId to indicate this is a merged shuffle block.
             splitsByAddress.getOrElseUpdate(mergeStatus.location, ListBuffer()) +=
-              ((ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, partId), mergeStatus.totalSize,
-                SHUFFLE_PUSH_MAP_ID))
+              ((ShuffleMergedBlockId(shuffleId, mergeStatus.shuffleMergeId, partId),
+                mergeStatus.totalSize, SHUFFLE_PUSH_MAP_ID))
             // For the "holes" in this pre-merged shuffle partition, i.e., unmerged mapper
             // shuffle partition blocks, fetch the original map produced shuffle partition blocks
             val mapStatusesWithIndex = mapStatuses.zipWithIndex
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 9709ec1..b276de1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1323,9 +1323,8 @@ private[spark] class DAGScheduler(
     // `findMissingPartitions()` returns all partitions every time.
     stage match {
       case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable =>
-        // TODO: SPARK-32923: Clean all push-based shuffle metadata like merge enabled and
-        // TODO: finalized as we are clearing all the merge results.
         mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
+        sms.shuffleDep.newShuffleMergeState()
       case _ =>
     }
 
@@ -2057,7 +2056,7 @@ private[spark] class DAGScheduler(
           // TODO: SPARK-35536: Cancel finalizeShuffleMerge if the stage is cancelled
           // TODO: during shuffleMergeFinalizeWaitSec
           shuffleClient.finalizeShuffleMerge(shuffleServiceLoc.host,
-            shuffleServiceLoc.port, shuffleId,
+            shuffleServiceLoc.port, shuffleId, stage.shuffleDep.shuffleMergeId,
             new MergeFinalizerListener {
               override def onShuffleMergeSuccess(statuses: MergeStatuses): Unit = {
                 assert(shuffleId == statuses.shuffleId)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala
index 77d8f8e..6d16026 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala
@@ -43,14 +43,17 @@ import org.apache.spark.util.Utils
  */
 private[spark] class MergeStatus(
     private[this] var loc: BlockManagerId,
+    private[this] var _shuffleMergeId: Int,
     private[this] var mapTracker: RoaringBitmap,
     private[this] var size: Long)
   extends Externalizable with ShuffleOutputStatus {
 
-  protected def this() = this(null, null, -1) // For deserialization only
+  protected def this() = this(null, -1, null, -1) // For deserialization only
 
   def location: BlockManagerId = loc
 
+  def shuffleMergeId: Int = _shuffleMergeId
+
   def totalSize: Long = size
 
   def tracker: RoaringBitmap = mapTracker
@@ -73,12 +76,14 @@ private[spark] class MergeStatus(
 
   override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
     loc.writeExternal(out)
+    out.writeInt(_shuffleMergeId)
     mapTracker.writeExternal(out)
     out.writeLong(size)
   }
 
   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
     loc = BlockManagerId(in)
+    _shuffleMergeId = in.readInt()
     mapTracker = new RoaringBitmap()
     mapTracker.readExternal(in)
     size = in.readLong()
@@ -100,14 +105,20 @@ private[spark] object MergeStatus {
     assert(mergeStatuses.bitmaps.length == mergeStatuses.reduceIds.length &&
       mergeStatuses.bitmaps.length == mergeStatuses.sizes.length)
     val mergerLoc = BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, loc.host, loc.port)
+    val shuffleMergeId = mergeStatuses.shuffleMergeId
     mergeStatuses.bitmaps.zipWithIndex.map {
       case (bitmap, index) =>
-        val mergeStatus = new MergeStatus(mergerLoc, bitmap, mergeStatuses.sizes(index))
+        val mergeStatus = new MergeStatus(mergerLoc, shuffleMergeId, bitmap,
+          mergeStatuses.sizes(index))
         (mergeStatuses.reduceIds(index), mergeStatus)
     }
   }
 
-  def apply(loc: BlockManagerId, bitmap: RoaringBitmap, size: Long): MergeStatus = {
-    new MergeStatus(loc, bitmap, size)
+  def apply(
+      loc: BlockManagerId,
+      shuffleMergeId: Int,
+      bitmap: RoaringBitmap,
+      size: Long): MergeStatus = {
+    new MergeStatus(loc, shuffleMergeId, bitmap, size)
   }
 }
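
The (de)serialization above is position-sensitive, so driver and executors
must agree that shuffleMergeId sits between the location and the bitmap. A
stripped-down round-trip sketch; StatusStub is a hypothetical stand-in that
keeps only the primitive fields and elides BlockManagerId and the
RoaringBitmap.

    import java.io._

    // Minimal model of MergeStatus's externalized form after this patch: an
    // int shuffleMergeId followed by the long size, with the surrounding
    // location and bitmap elided.
    final class StatusStub(var shuffleMergeId: Int, var size: Long)
        extends Externalizable {
      def this() = this(-1, -1L)  // deserialization-only constructor
      override def writeExternal(out: ObjectOutput): Unit = {
        out.writeInt(shuffleMergeId)
        out.writeLong(size)
      }
      override def readExternal(in: ObjectInput): Unit = {
        shuffleMergeId = in.readInt()
        size = in.readLong()
      }
    }

    object StatusStubDemo extends App {
      val bytes = new ByteArrayOutputStream()
      val out = new ObjectOutputStream(bytes)
      new StatusStub(2, 9L).writeExternal(out)
      out.flush()
      val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
      val decoded = new StatusStub()
      decoded.readExternal(in)
      assert(decoded.shuffleMergeId == 2 && decoded.size == 9L)
    }
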
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index 9c50569..07928f8 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -116,28 +116,31 @@ private[spark] class IndexShuffleBlockResolver(
   private def getMergedBlockDataFile(
       appId: String,
       shuffleId: Int,
+      shuffleMergeId: Int,
       reduceId: Int,
       dirs: Option[Array[String]] = None): File = {
     blockManager.diskBlockManager.getMergedShuffleFile(
-      ShuffleMergedDataBlockId(appId, shuffleId, reduceId), dirs)
+      ShuffleMergedDataBlockId(appId, shuffleId, shuffleMergeId, reduceId), dirs)
   }
 
   private def getMergedBlockIndexFile(
       appId: String,
       shuffleId: Int,
+      shuffleMergeId: Int,
       reduceId: Int,
       dirs: Option[Array[String]] = None): File = {
     blockManager.diskBlockManager.getMergedShuffleFile(
-      ShuffleMergedIndexBlockId(appId, shuffleId, reduceId), dirs)
+      ShuffleMergedIndexBlockId(appId, shuffleId, shuffleMergeId, reduceId), dirs)
   }
 
   private def getMergedBlockMetaFile(
       appId: String,
       shuffleId: Int,
+      shuffleMergeId: Int,
       reduceId: Int,
       dirs: Option[Array[String]] = None): File = {
     blockManager.diskBlockManager.getMergedShuffleFile(
-      ShuffleMergedMetaBlockId(appId, shuffleId, reduceId), dirs)
+      ShuffleMergedMetaBlockId(appId, shuffleId, shuffleMergeId, reduceId), dirs)
   }
 
   /**
@@ -466,11 +469,13 @@ private[spark] class IndexShuffleBlockResolver(
    * knows how to consume local merged shuffle file as multiple chunks.
    */
   override def getMergedBlockData(
-      blockId: ShuffleBlockId,
+      blockId: ShuffleMergedBlockId,
       dirs: Option[Array[String]]): Seq[ManagedBuffer] = {
     val indexFile =
-      getMergedBlockIndexFile(conf.getAppId, blockId.shuffleId, blockId.reduceId, dirs)
-    val dataFile = getMergedBlockDataFile(conf.getAppId, blockId.shuffleId, blockId.reduceId, dirs)
+      getMergedBlockIndexFile(conf.getAppId, blockId.shuffleId, blockId.shuffleMergeId,
+        blockId.reduceId, dirs)
+    val dataFile = getMergedBlockDataFile(conf.getAppId, blockId.shuffleId,
+      blockId.shuffleMergeId, blockId.reduceId, dirs)
     // Load all the indexes in order to identify all chunks in the specified merged shuffle file.
     val size = indexFile.length.toInt
     val offsets = Utils.tryWithResource {
@@ -493,13 +498,15 @@ private[spark] class IndexShuffleBlockResolver(
    * This is only used for reading local merged block meta data.
    */
   override def getMergedBlockMeta(
-      blockId: ShuffleBlockId,
+      blockId: ShuffleMergedBlockId,
       dirs: Option[Array[String]]): MergedBlockMeta = {
     val indexFile =
-      getMergedBlockIndexFile(conf.getAppId, blockId.shuffleId, blockId.reduceId, dirs)
+      getMergedBlockIndexFile(conf.getAppId, blockId.shuffleId,
+        blockId.shuffleMergeId, blockId.reduceId, dirs)
     val size = indexFile.length.toInt
     val numChunks = (size / 8) - 1
-    val metaFile = getMergedBlockMetaFile(conf.getAppId, blockId.shuffleId, blockId.reduceId, dirs)
+    val metaFile = getMergedBlockMetaFile(conf.getAppId, blockId.shuffleId,
+      blockId.shuffleMergeId, blockId.reduceId, dirs)
     val chunkBitMaps = new FileSegmentManagedBuffer(transportConf, metaFile, 0L, metaFile.length)
     new MergedBlockMeta(numChunks, chunkBitMaps)
   }
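
The numChunks arithmetic above follows from the merged index file layout: the
file holds (numChunks + 1) eight-byte offsets that delimit the chunks. A small
sketch with illustrative offsets (the 4- and 5-byte chunk sizes mirror the
test data used elsewhere in this patch):

    // Sketch of the index-file layout assumed by getMergedBlockMeta: three
    // offsets (24 bytes) delimit two chunks, [0, 4) and [4, 9).
    object IndexLayoutDemo extends App {
      val offsets = Array(0L, 4L, 9L)
      val indexFileSize = offsets.length * 8  // 24 bytes on disk
      val numChunks = (indexFileSize / 8) - 1
      assert(numChunks == 2)
      val chunkSizes = offsets.sliding(2).map(p => p(1) - p(0)).toSeq
      assert(chunkSizes == Seq(4L, 5L))
    }
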
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
index 0d2462f..56f915b 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
@@ -69,7 +69,8 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
     new BlockPushErrorHandler() {
       // For a connection exception against a particular host, we will stop pushing any
       // blocks to just that host and continue push blocks to other hosts. So, here push of
-      // all blocks will only stop when it is "Too Late". Also see updateStateAndCheckIfPushMore.
+      // all blocks will only stop when it is "Too Late" or "Invalid Block Push".
+      // Also see updateStateAndCheckIfPushMore.
       override def shouldRetryError(t: Throwable): Boolean = {
         // If it is a FileNotFoundException originating from the client while pushing the shuffle
         // blocks to the server, then we stop pushing all the blocks because this indicates the
@@ -77,8 +78,10 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
         if (t.getCause != null && t.getCause.isInstanceOf[FileNotFoundException]) {
           return false
         }
-        // If the block is too late, there is no need to retry it
-        !Throwables.getStackTraceAsString(t).contains(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX)
+        val errorStackTraceString = Throwables.getStackTraceAsString(t)
+        // If the block is too late or the block push is invalid, there is no need to retry it
+        !errorStackTraceString.contains(
+          BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX)
       }
     }
   }
@@ -99,8 +102,8 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
       mapIndex: Int): Unit = {
     val numPartitions = dep.partitioner.numPartitions
     val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle")
-    val requests = prepareBlockPushRequests(numPartitions, mapIndex, dep.shuffleId, dataFile,
-      partitionLengths, dep.getMergerLocs, transportConf)
+    val requests = prepareBlockPushRequests(numPartitions, mapIndex, dep.shuffleId,
+      dep.shuffleMergeId, dataFile, partitionLengths, dep.getMergerLocs, transportConf)
     // Randomize the orders of the PushRequest, so different mappers pushing blocks at the same
     // time won't be pushing the same ranges of shuffle partitions.
     pushRequests ++= Utils.randomize(requests)
@@ -335,6 +338,8 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
    * @param numPartitions number of shuffle partitions in the shuffle file
    * @param partitionId map index of the current mapper
    * @param shuffleId shuffleId of current shuffle
+   * @param shuffleMergeId shuffleMergeId is used to uniquely identify the merging
+   *                       process of a shuffle by an indeterminate stage attempt.
    * @param dataFile shuffle data file
    * @param partitionLengths array of sizes of blocks in the shuffle data file
    * @param mergerLocs target locations to push blocks to
@@ -347,6 +352,7 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
       numPartitions: Int,
       partitionId: Int,
       shuffleId: Int,
+      shuffleMergeId: Int,
       dataFile: File,
       partitionLengths: Array[Long],
       mergerLocs: Seq[BlockManagerId],
@@ -361,7 +367,8 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
     for (reduceId <- 0 until numPartitions) {
       val blockSize = partitionLengths(reduceId)
       logDebug(
-        s"Block ${ShufflePushBlockId(shuffleId, partitionId, reduceId)} is of size $blockSize")
+        s"Block ${ShufflePushBlockId(shuffleId, shuffleMergeId, partitionId,
+          reduceId)} is of size $blockSize")
       // Skip 0-length blocks and blocks that are large enough
       if (blockSize > 0) {
         val mergerId = math.min(math.floor(reduceId * 1.0 / numPartitions * numMergers),
@@ -394,7 +401,8 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
         // Only push blocks under the size limit
         if (blockSize <= maxBlockSizeToPush) {
           val blockSizeInt = blockSize.toInt
-          blocks += ((ShufflePushBlockId(shuffleId, partitionId, reduceId), blockSizeInt))
+          blocks += ((ShufflePushBlockId(shuffleId, shuffleMergeId, partitionId,
+            reduceId), blockSizeInt))
           // Only update currentReqOffset if the current block is the first in the request
           if (currentReqOffset == -1) {
             currentReqOffset = offset
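
The merger selection visible in the surrounding context, math.min(math.floor(
reduceId * 1.0 / numPartitions * numMergers), numMergers - 1), range-partitions
reduce ids across mergers so that contiguous reduce partitions go to the same
merger. A self-contained sketch with illustrative sizes:

    // 8 reduce partitions spread over 3 mergers: each merger receives a
    // contiguous range of reduceIds, and the last merger takes the remainder.
    object MergerAssignmentDemo extends App {
      val numPartitions = 8
      val numMergers = 3
      val assignment = (0 until numPartitions).map { reduceId =>
        math.min(math.floor(reduceId * 1.0 / numPartitions * numMergers).toInt,
          numMergers - 1)
      }
      assert(assignment == Seq(0, 0, 0, 1, 1, 1, 2, 2))
    }
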
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala
index 49e5929..0f35f8c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala
@@ -19,7 +19,7 @@ package org.apache.spark.shuffle
 
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.shuffle.MergedBlockMeta
-import org.apache.spark.storage.{BlockId, ShuffleBlockId}
+import org.apache.spark.storage.{BlockId, ShuffleMergedBlockId}
 
 private[spark]
 /**
@@ -44,12 +44,16 @@ trait ShuffleBlockResolver {
   /**
    * Retrieve the data for the specified merged shuffle block as multiple chunks.
    */
-  def getMergedBlockData(blockId: ShuffleBlockId, dirs: Option[Array[String]]): Seq[ManagedBuffer]
+  def getMergedBlockData(
+      blockId: ShuffleMergedBlockId,
+      dirs: Option[Array[String]]): Seq[ManagedBuffer]
 
   /**
    * Retrieve the meta data for the specified merged shuffle block.
    */
-  def getMergedBlockMeta(blockId: ShuffleBlockId, dirs: Option[Array[String]]): MergedBlockMeta
+  def getMergedBlockMeta(
+      blockId: ShuffleMergedBlockId,
+      dirs: Option[Array[String]]): MergedBlockMeta
 
   def stop(): Unit
 }
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index db5862d..ce53f08 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -77,9 +77,11 @@ case class ShuffleBlockBatchId(
 @DeveloperApi
 case class ShuffleBlockChunkId(
     shuffleId: Int,
+    shuffleMergeId: Int,
     reduceId: Int,
     chunkId: Int) extends BlockId {
-  override def name: String = "shuffleChunk_" + shuffleId  + "_" + reduceId + "_" + chunkId
+  override def name: String =
+    "shuffleChunk_" + shuffleId  + "_" + shuffleMergeId + "_" + reduceId + "_" + chunkId
 }
 
 @DeveloperApi
@@ -100,15 +102,34 @@ case class ShuffleChecksumBlockId(shuffleId: Int, mapId: Long, reduceId: Int) ex
 
 @Since("3.2.0")
 @DeveloperApi
-case class ShufflePushBlockId(shuffleId: Int, mapIndex: Int, reduceId: Int) extends BlockId {
-  override def name: String = "shufflePush_" + shuffleId + "_" + mapIndex + "_" + reduceId
+case class ShufflePushBlockId(
+    shuffleId: Int,
+    shuffleMergeId: Int,
+    mapIndex: Int,
+    reduceId: Int) extends BlockId {
+  override def name: String = "shufflePush_" + shuffleId + "_" +
+    shuffleMergeId + "_" + mapIndex + "_" + reduceId + ""
 }
 
 @Since("3.2.0")
 @DeveloperApi
-case class ShuffleMergedDataBlockId(appId: String, shuffleId: Int, reduceId: Int) extends BlockId {
+case class ShuffleMergedBlockId(
+    shuffleId: Int,
+    shuffleMergeId: Int,
+    reduceId: Int) extends BlockId {
+  override def name: String = "shuffleMerged_" + shuffleId + "_" +
+    shuffleMergeId + "_" + reduceId
+}
+
+@Since("3.2.0")
+@DeveloperApi
+case class ShuffleMergedDataBlockId(
+    appId: String,
+    shuffleId: Int,
+    shuffleMergeId: Int,
+    reduceId: Int) extends BlockId {
   override def name: String = RemoteBlockPushResolver.MERGED_SHUFFLE_FILE_NAME_PREFIX + "_" +
-    appId + "_" + shuffleId + "_" + reduceId + ".data"
+    appId + "_" + shuffleId + "_" + shuffleMergeId + "_" + reduceId + ".data"
 }
 
 @Since("3.2.0")
@@ -116,9 +137,10 @@ case class ShuffleMergedDataBlockId(appId: String, shuffleId: Int, reduceId: Int
 case class ShuffleMergedIndexBlockId(
     appId: String,
     shuffleId: Int,
+    shuffleMergeId: Int,
     reduceId: Int) extends BlockId {
   override def name: String = RemoteBlockPushResolver.MERGED_SHUFFLE_FILE_NAME_PREFIX + "_" +
-    appId + "_" + shuffleId + "_" + reduceId + ".index"
+    appId + "_" + shuffleId + "_" + shuffleMergeId + "_" + reduceId + ".index"
 }
 
 @Since("3.2.0")
@@ -126,9 +148,10 @@ case class ShuffleMergedIndexBlockId(
 case class ShuffleMergedMetaBlockId(
     appId: String,
     shuffleId: Int,
+    shuffleMergeId: Int,
     reduceId: Int) extends BlockId {
   override def name: String = RemoteBlockPushResolver.MERGED_SHUFFLE_FILE_NAME_PREFIX + "_" +
-    appId + "_" + shuffleId + "_" + reduceId + ".meta"
+    appId + "_" + shuffleId + "_" + shuffleMergeId + "_" + reduceId + ".meta"
 }
 
 @DeveloperApi
@@ -172,11 +195,15 @@ object BlockId {
   val SHUFFLE_BATCH = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r
   val SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).data".r
   val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r
-  val SHUFFLE_PUSH = "shufflePush_([0-9]+)_([0-9]+)_([0-9]+)".r
-  val SHUFFLE_MERGED_DATA = "shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+).data".r
-  val SHUFFLE_MERGED_INDEX = "shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+).index".r
-  val SHUFFLE_MERGED_META = "shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+).meta".r
-  val SHUFFLE_CHUNK = "shuffleChunk_([0-9]+)_([0-9]+)_([0-9]+)".r
+  val SHUFFLE_PUSH = "shufflePush_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r
+  val SHUFFLE_MERGED = "shuffleMerged_([0-9]+)_([0-9]+)_([0-9]+)".r
+  val SHUFFLE_MERGED_DATA =
+    "shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+)_([0-9]+).data".r
+  val SHUFFLE_MERGED_INDEX =
+    "shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+)_([0-9]+).index".r
+  val SHUFFLE_MERGED_META =
+    "shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+)_([0-9]+).meta".r
+  val SHUFFLE_CHUNK = "shuffleChunk_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r
   val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r
   val TASKRESULT = "taskresult_([0-9]+)".r
   val STREAM = "input-([0-9]+)-([0-9]+)".r
@@ -195,16 +222,22 @@ object BlockId {
       ShuffleDataBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt)
     case SHUFFLE_INDEX(shuffleId, mapId, reduceId) =>
       ShuffleIndexBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt)
-    case SHUFFLE_PUSH(shuffleId, mapIndex, reduceId) =>
-      ShufflePushBlockId(shuffleId.toInt, mapIndex.toInt, reduceId.toInt)
-    case SHUFFLE_MERGED_DATA(appId, shuffleId, reduceId) =>
-      ShuffleMergedDataBlockId(appId, shuffleId.toInt, reduceId.toInt)
-    case SHUFFLE_MERGED_INDEX(appId, shuffleId, reduceId) =>
-      ShuffleMergedIndexBlockId(appId, shuffleId.toInt, reduceId.toInt)
-    case SHUFFLE_MERGED_META(appId, shuffleId, reduceId) =>
-      ShuffleMergedMetaBlockId(appId, shuffleId.toInt, reduceId.toInt)
-    case SHUFFLE_CHUNK(shuffleId, reduceId, chunkId) =>
-      ShuffleBlockChunkId(shuffleId.toInt, reduceId.toInt, chunkId.toInt)
+    case SHUFFLE_PUSH(shuffleId, shuffleMergeId, mapIndex, reduceId) =>
+      ShufflePushBlockId(shuffleId.toInt, shuffleMergeId.toInt, mapIndex.toInt,
+        reduceId.toInt)
+    case SHUFFLE_MERGED(shuffleId, shuffleMergeId, reduceId) =>
+      ShuffleMergedBlockId(shuffleId.toInt, shuffleMergeId.toInt, reduceId.toInt)
+    case SHUFFLE_MERGED_DATA(appId, shuffleId, shuffleMergeId, reduceId) =>
+      ShuffleMergedDataBlockId(appId, shuffleId.toInt, shuffleMergeId.toInt, reduceId.toInt)
+    case SHUFFLE_MERGED_INDEX(appId, shuffleId, shuffleMergeId, reduceId) =>
+      ShuffleMergedIndexBlockId(appId, shuffleId.toInt, shuffleMergeId.toInt,
+        reduceId.toInt)
+    case SHUFFLE_MERGED_META(appId, shuffleId, shuffleMergeId, reduceId) =>
+      ShuffleMergedMetaBlockId(appId, shuffleId.toInt, shuffleMergeId.toInt,
+        reduceId.toInt)
+    case SHUFFLE_CHUNK(shuffleId, shuffleMergeId, reduceId, chunkId) =>
+      ShuffleBlockChunkId(shuffleId.toInt, shuffleMergeId.toInt, reduceId.toInt,
+        chunkId.toInt)
     case BROADCAST(broadcastId, field) =>
       BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_"))
     case TASKRESULT(taskId) =>
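
All push-based shuffle block names now carry shuffleMergeId as their second
numeric component, and the widened patterns above no longer match the old
three-component names, so client and server have to agree on the new format.
A quick sketch using the SHUFFLE_PUSH pattern from the patch (the sample names
are illustrative):

    // The regex is copied verbatim from BlockId.SHUFFLE_PUSH above.
    object BlockIdNameDemo extends App {
      val SHUFFLE_PUSH = "shufflePush_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r
      "shufflePush_0_2_1_0" match {
        case SHUFFLE_PUSH(shuffleId, shuffleMergeId, mapIndex, reduceId) =>
          assert(shuffleId.toInt == 0 && shuffleMergeId.toInt == 2 &&
            mapIndex.toInt == 1 && reduceId.toInt == 0)
      }
      // An old-style three-component name no longer parses as a push block.
      assert(SHUFFLE_PUSH.unapplySeq("shufflePush_0_1_0").isEmpty)
    }
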
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 876ac54..b81b3b6 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -750,7 +750,7 @@ private[spark] class BlockManager(
    * which will be memory efficient when performing certain operations.
    */
   def getLocalMergedBlockData(
-      blockId: ShuffleBlockId,
+      blockId: ShuffleMergedBlockId,
       dirs: Array[String]): Seq[ManagedBuffer] = {
     shuffleManager.shuffleBlockResolver.getMergedBlockData(blockId, Some(dirs))
   }
@@ -759,7 +759,7 @@ private[spark] class BlockManager(
    * Get the local merged shuffle block meta data for the given block ID.
    */
   def getLocalMergedBlockMeta(
-      blockId: ShuffleBlockId,
+      blockId: ShuffleMergedBlockId,
       dirs: Array[String]): MergedBlockMeta = {
     shuffleManager.shuffleBlockResolver.getMergedBlockMeta(blockId, Some(dirs))
   }
diff --git a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
index 096ea24..99138b6 100644
--- a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
+++ b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
@@ -110,13 +110,14 @@ private class PushBasedFetchHelper(
    */
   def createChunkBlockInfosFromMetaResponse(
       shuffleId: Int,
+      shuffleMergeId: Int,
       reduceId: Int,
       blockSize: Long,
       bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
     val approxChunkSize = blockSize / bitmaps.length
     val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
     for (i <- bitmaps.indices) {
-      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, shuffleMergeId, reduceId, i)
       chunksMetaMap.put(blockChunkId, bitmaps(i))
       logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
       blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
@@ -134,37 +135,41 @@ private class PushBasedFetchHelper(
   def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
     val sizeMap = req.blocks.map {
       case FetchBlockInfo(blockId, size, _) =>
-        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleMergedBlockId]
         ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)
     }.toMap
     val address = req.address
     val mergedBlocksMetaListener = new MergedBlocksMetaListener {
-      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
-        logInfo(s"Received the meta of push-merged block for ($shuffleId, $reduceId)  " +
-          s"from ${req.address.host}:${req.address.port}")
+      override def onSuccess(shuffleId: Int, shuffleMergeId: Int, reduceId: Int,
+          meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of push-merged block for ($shuffleId, $shuffleMergeId," +
+          s" $reduceId) from ${req.address.host}:${req.address.port}")
         try {
-          iterator.addToResultsQueue(PushMergedRemoteMetaFetchResult(shuffleId, reduceId,
-            sizeMap((shuffleId, reduceId)), meta.readChunkBitmaps(), address))
+          iterator.addToResultsQueue(PushMergedRemoteMetaFetchResult(shuffleId, shuffleMergeId,
+            reduceId, sizeMap((shuffleId, reduceId)), meta.readChunkBitmaps(), address))
         } catch {
           case exception: Exception =>
             logError(s"Failed to parse the meta of push-merged block for ($shuffleId, " +
-              s"$reduceId) from ${req.address.host}:${req.address.port}", exception)
+              s"$shuffleMergeId, $reduceId) from" +
+              s" ${req.address.host}:${req.address.port}", exception)
             iterator.addToResultsQueue(
-              PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address))
+              PushMergedRemoteMetaFailedFetchResult(shuffleId, shuffleMergeId, reduceId,
+                address))
         }
       }
 
-      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+      override def onFailure(shuffleId: Int, shuffleMergeId: Int, reduceId: Int,
+          exception: Throwable): Unit = {
         logError(s"Failed to get the meta of push-merged block for ($shuffleId, $reduceId) " +
           s"from ${req.address.host}:${req.address.port}", exception)
         iterator.addToResultsQueue(
-          PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address))
+          PushMergedRemoteMetaFailedFetchResult(shuffleId, shuffleMergeId, reduceId, address))
       }
     }
     req.blocks.foreach { block =>
-      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleMergedBlockId]
       shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
-        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+        shuffleBlockId.shuffleMergeId, shuffleBlockId.reduceId, mergedBlocksMetaListener)
     }
   }
 
@@ -241,11 +246,11 @@ private class PushBasedFetchHelper(
       localDirs: Array[String],
       blockManagerId: BlockManagerId): Unit = {
     try {
-      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleMergedBlockId]
       val chunksMeta = blockManager.getLocalMergedBlockMeta(shuffleBlockId, localDirs)
       iterator.addToResultsQueue(PushMergedLocalMetaFetchResult(
-        shuffleBlockId.shuffleId, shuffleBlockId.reduceId, chunksMeta.readChunkBitmaps(),
-        localDirs))
+        shuffleBlockId.shuffleId, shuffleBlockId.shuffleMergeId,
+        shuffleBlockId.reduceId, chunksMeta.readChunkBitmaps(), localDirs))
     } catch {
       case e: Exception =>
         // If we see an exception with reading a push-merged-local meta, we fallback to
@@ -283,13 +288,13 @@ private class PushBasedFetchHelper(
   def initiateFallbackFetchForPushMergedBlock(
       blockId: BlockId,
       address: BlockManagerId): Unit = {
-    assert(blockId.isInstanceOf[ShuffleBlockId] || blockId.isInstanceOf[ShuffleBlockChunkId])
+    assert(blockId.isInstanceOf[ShuffleMergedBlockId] || blockId.isInstanceOf[ShuffleBlockChunkId])
     logWarning(s"Falling back to fetch the original blocks for push-merged block $blockId")
     // Increase the blocks processed since we will process another block in the next iteration of
     // the while loop in ShuffleBlockFetcherIterator.next().
     val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
       blockId match {
-        case shuffleBlockId: ShuffleBlockId =>
+        case shuffleBlockId: ShuffleMergedBlockId =>
           iterator.decreaseNumBlocksToFetch(1)
           mapOutputTracker.getMapSizesForMergeResult(
             shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
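
createChunkBlockInfosFromMetaResponse above turns one merged block's meta into
per-chunk fetch entries, one per bitmap, each sized as an even share of the
merged block. A self-contained sketch with illustrative values; the chunk-id
strings mirror ShuffleBlockChunkId.name from this patch.

    // A 90-byte merged block whose meta carries 3 bitmaps yields 3 chunk
    // entries of ~30 bytes each, all tagged with the same shuffleMergeId.
    object ChunkInfoDemo extends App {
      val (shuffleId, shuffleMergeId, reduceId) = (0, 2, 5)
      val blockSize = 90L
      val numChunks = 3  // one bitmap per chunk in the merged block meta
      val approxChunkSize = blockSize / numChunks
      val blocksToFetch = (0 until numChunks).map { i =>
        (s"shuffleChunk_${shuffleId}_${shuffleMergeId}_${reduceId}_$i",
          approxChunkSize)
      }
      assert(blocksToFetch.head == ("shuffleChunk_0_2_5_0", 30L))
    }
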
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index d03f20a..fd87f5e 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -490,14 +490,14 @@ final class ShuffleBlockFetcherIterator(
         // Either all blocks are push-merged blocks, shuffle chunks, or original blocks.
         // Based on these types, we decide to do batch fetch and create FetchRequests with
         // forMergedMetas set.
-        case ShuffleBlockChunkId(_, _, _) =>
+        case ShuffleBlockChunkId(_, _, _, _) =>
           if (curRequestSize >= targetRemoteRequestSize ||
             curBlocks.size >= maxBlocksInFlightPerAddress) {
             curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false,
               collectedRemoteRequests, enableBatchFetch = false)
             curRequestSize = curBlocks.map(_.size).sum
           }
-        case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) =>
+        case ShuffleMergedBlockId(_, _, _) =>
           if (curBlocks.size >= maxBlocksInFlightPerAddress) {
             curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false,
               collectedRemoteRequests, enableBatchFetch = false, forMergedMetas = true)
@@ -516,8 +516,8 @@ final class ShuffleBlockFetcherIterator(
     if (curBlocks.nonEmpty) {
       val (enableBatchFetch, forMergedMetas) = {
         curBlocks.head.blockId match {
-          case ShuffleBlockChunkId(_, _, _) => (false, false)
-          case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) => (false, true)
+          case ShuffleBlockChunkId(_, _, _, _) => (false, false)
+          case ShuffleMergedBlockId(_, _, _) => (false, true)
           case _ => (doBatchFetch, false)
         }
       }
@@ -901,9 +901,10 @@ final class ShuffleBlockFetcherIterator(
           // a SuccessFetchResult or a FailureFetchResult.
           result = null
 
-          case PushMergedLocalMetaFetchResult(shuffleId, reduceId, bitmaps, localDirs) =>
+          case PushMergedLocalMetaFetchResult(
+            shuffleId, shuffleMergeId, reduceId, bitmaps, localDirs) =>
             // Fetch push-merged-local shuffle block data as multiple shuffle chunks
-            val shuffleBlockId = ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, reduceId)
+            val shuffleBlockId = ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId)
             try {
               val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(shuffleBlockId,
                 localDirs)
@@ -915,7 +916,8 @@ final class ShuffleBlockFetcherIterator(
               numBlocksToFetch += bufs.size
               bufs.zipWithIndex.foreach { case (buf, chunkId) =>
                 buf.retain()
-                val shuffleChunkId = ShuffleBlockChunkId(shuffleId, reduceId, chunkId)
+                val shuffleChunkId = ShuffleBlockChunkId(shuffleId, shuffleMergeId, reduceId,
+                  chunkId)
                 pushBasedFetchHelper.addChunk(shuffleChunkId, bitmaps(chunkId))
                 results.put(SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID,
                   pushBasedFetchHelper.localShuffleMergerBlockMgrId, buf.size(), buf,
@@ -933,27 +935,29 @@ final class ShuffleBlockFetcherIterator(
             }
             result = null
 
-        case PushMergedRemoteMetaFetchResult(shuffleId, reduceId, blockSize, bitmaps, address) =>
+        case PushMergedRemoteMetaFetchResult(
+          shuffleId, shuffleMergeId, reduceId, blockSize, bitmaps, address) =>
           // The original meta request is processed so we decrease numBlocksToFetch and
           // numBlocksInFlightPerAddress by 1. We will collect new shuffle chunks request and the
           // count of this is added to numBlocksToFetch in collectFetchReqsFromMergedBlocks.
           numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
           numBlocksToFetch -= 1
           val blocksToFetch = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse(
-            shuffleId, reduceId, blockSize, bitmaps)
+            shuffleId, shuffleMergeId, reduceId, blockSize, bitmaps)
           val additionalRemoteReqs = new ArrayBuffer[FetchRequest]
           collectFetchRequests(address, blocksToFetch.toSeq, additionalRemoteReqs)
           fetchRequests ++= additionalRemoteReqs
           // Set result to null to force another iteration.
           result = null
 
-        case PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address) =>
+        case PushMergedRemoteMetaFailedFetchResult(
+          shuffleId, shuffleMergeId, reduceId, address) =>
           // The original meta request failed so we decrease numBlocksInFlightPerAddress by 1.
           numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
           // If we fail to fetch the meta of a push-merged block, we fall back to fetching the
           // original blocks.
           pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(
-            ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, reduceId), address)
+            ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId), address)
           // Set result to null to force another iteration.
           result = null
       }
@@ -1421,6 +1425,8 @@ object ShuffleBlockFetcherIterator {
    * Result of a successful fetch of meta information for a remote push-merged block.
    *
    * @param shuffleId shuffle id.
+   * @param shuffleMergeId shuffleMergeId is used to uniquely identify the merging
+   *                       process of a shuffle by an indeterminate stage attempt.
    * @param reduceId reduce id.
    * @param blockSize size of each push-merged block.
    * @param bitmaps bitmaps for every chunk.
@@ -1428,6 +1434,7 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage] case class PushMergedRemoteMetaFetchResult(
       shuffleId: Int,
+      shuffleMergeId: Int,
       reduceId: Int,
       blockSize: Long,
       bitmaps: Array[RoaringBitmap],
@@ -1437,11 +1444,14 @@ object ShuffleBlockFetcherIterator {
    * Result of a failure while fetching the meta information for a remote push-merged block.
    *
    * @param shuffleId shuffle id.
+   * @param shuffleMergeId shuffleMergeId is used to uniquely identify the merging
+   *                       process of a shuffle by an indeterminate stage attempt.
    * @param reduceId reduce id.
    * @param address BlockManager that the meta was fetched from.
    */
   private[storage] case class PushMergedRemoteMetaFailedFetchResult(
       shuffleId: Int,
+      shuffleMergeId: Int,
       reduceId: Int,
       address: BlockManagerId) extends FetchResult
 
@@ -1449,12 +1459,15 @@ object ShuffleBlockFetcherIterator {
    * Result of a successful fetch of meta information for a push-merged-local block.
    *
    * @param shuffleId shuffle id.
+   * @param shuffleMergeId shuffleMergeId is used to uniquely identify the merging
+   *                       process of a shuffle by an indeterminate stage attempt.
    * @param reduceId reduce id.
    * @param bitmaps bitmaps for every chunk.
    * @param localDirs local directories where the push-merged shuffle files are stored.
    */
   private[storage] case class PushMergedLocalMetaFetchResult(
       shuffleId: Int,
+      shuffleMergeId: Int,
       reduceId: Int,
       bitmaps: Array[RoaringBitmap],
       localDirs: Array[String]) extends FetchResult
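
A minimal, self-contained Scala sketch (illustrative stand-ins, not the actual Spark classes
above) of how shuffleMergeId flows from a failed meta fetch into the fallback block id, mirroring
the PushMergedRemoteMetaFailedFetchResult case handled earlier:

    object MergeIdFlowSketch {
      // Simplified stand-ins for the fetch-result and block-id classes above.
      case class MetaFailed(shuffleId: Int, shuffleMergeId: Int, reduceId: Int)
      case class MergedBlockId(shuffleId: Int, shuffleMergeId: Int, reduceId: Int) {
        def name: String = s"shuffleMerged_${shuffleId}_${shuffleMergeId}_$reduceId"
      }

      def fallbackBlockFor(result: MetaFailed): MergedBlockId =
        MergedBlockId(result.shuffleId, result.shuffleMergeId, result.reduceId)

      def main(args: Array[String]): Unit = {
        // shuffleId = 0, shuffleMergeId = 1 (first indeterminate attempt), reduceId = 2
        println(fallbackBlockFor(MetaFailed(0, 1, 2)).name) // shuffleMerged_0_1_2
      }
    }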
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index f4b47e2..69cc8c1 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.internal.config.Tests.IS_TESTING
 import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv}
 import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus, MergeStatus}
 import org.apache.spark.shuffle.FetchFailedException
-import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId}
+import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId, ShuffleMergedBlockId}
 
 class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
   private val conf = new SparkConf
@@ -347,9 +347,9 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
     bitmap.add(0)
     bitmap.add(1)
 
-    tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000),
+    tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), 0,
       bitmap, 1000L))
-    tracker.registerMergeResult(10, 1, MergeStatus(BlockManagerId("b", "hostB", 1000),
+    tracker.registerMergeResult(10, 1, MergeStatus(BlockManagerId("b", "hostB", 1000), 0,
       bitmap, 1000L))
     assert(tracker.getNumAvailableMergeResults(10) == 2)
     tracker.unregisterMergeResult(10, 0, BlockManagerId("a", "hostA", 1000))
@@ -386,12 +386,12 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
     masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2))
     masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3))
 
-    masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId,
+    masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId, 0,
       bitmap, 3000L))
     slaveTracker.updateEpoch(masterTracker.getEpoch)
     val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
     assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq ===
-      Seq((blockMgrId, ArrayBuffer((ShuffleBlockId(10, -1, 0), 3000, -1),
+      Seq((blockMgrId, ArrayBuffer((ShuffleMergedBlockId(10, 0, 0), 3000, -1),
         (ShuffleBlockId(10, 2, 0), size1000, 2)))))
 
     masterTracker.stop()
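
The assertion above shows the shape reducers now see: one merged entry whose block id is a
ShuffleMergedBlockId and whose mapIndex is -1, alongside per-map entries for blocks that were
not merged. A simplified sketch with plain tuples (not the real MapOutputTracker return type):

    object MapSizesSketch {
      def main(args: Array[String]): Unit = {
        val merged   = ("shuffleMerged_10_0_0", 3000L, -1) // mapIndex -1 marks a merged block
        val unmerged = ("shuffle_10_2_0", 1000L, 2)        // map 2's block was not merged
        Seq(merged, unmerged).foreach { case (block, size, mapIndex) =>
          println(s"$block size=$size mapIndex=$mapIndex")
        }
      }
    }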
@@ -431,7 +431,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
     masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2))
     masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3))
 
-    masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId,
+    masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId, 0,
       bitmap, 4000L))
     slaveTracker.updateEpoch(masterTracker.getEpoch)
     val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
@@ -523,7 +523,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
     bitmap80.add(2)
     bitmap80.add(3)
     bitmap80.add(4)
-    tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000),
+    tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), 0,
       bitmap80, 11))
 
     val preferredLocs1 = tracker.getPreferredLocationsForShuffle(mockShuffleDep, 0)
@@ -535,7 +535,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
     // Prepare another MergeStatus that merges only 1 out of 5 blocks
     val bitmap20 = new RoaringBitmap()
     bitmap20.add(0)
-    tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000),
+    tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), 0,
       bitmap20, 2))
 
     val preferredLocs2 = tracker.getPreferredLocationsForShuffle(mockShuffleDep, 0)
@@ -612,7 +612,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
         masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
           BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5))
       }
-      masterTracker.registerMergeResult(20, 0, MergeStatus(BlockManagerId("999", "mps", 1000),
+      masterTracker.registerMergeResult(20, 0, MergeStatus(BlockManagerId("999", "mps", 1000), 0,
         bitmap1, 1000L))
 
       val mapWorkerRpcEnv = createRpcEnv("spark-worker", "localhost", 0, new SecurityManager(conf))
@@ -646,13 +646,13 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
     val bitmap1 = new RoaringBitmap()
     bitmap1.add(0)
     bitmap1.add(1)
-    tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000),
+    tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), 0,
       bitmap1, 1000L))
 
     val bitmap2 = new RoaringBitmap()
     bitmap2.add(5)
     bitmap2.add(6)
-    tracker.registerMergeResult(10, 1, MergeStatus(BlockManagerId("b", "hostB", 1000),
+    tracker.registerMergeResult(10, 1, MergeStatus(BlockManagerId("b", "hostB", 1000), 0,
       bitmap2, 1000L))
     assert(tracker.getNumAvailableMergeResults(10) == 2)
     tracker.unregisterMergeResult(10, 0, BlockManagerId("a", "hostA", 1000), Option(0))
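
Since MergeStatus now carries the shuffleMergeId of the attempt that produced it, a tracker can
discard results from a superseded merge attempt. A hypothetical sketch of that idea (the Tracker
class below is illustrative, not the actual MapOutputTracker API):

    object StaleMergeSketch {
      case class MergeStatus(host: String, shuffleMergeId: Int, totalSize: Long)

      class Tracker(currentMergeId: Int) {
        private var statuses = List.empty[MergeStatus]
        def register(s: MergeStatus): Unit =
          if (s.shuffleMergeId == currentMergeId) statuses ::= s // drop stale attempts
        def available: Int = statuses.size
      }

      def main(args: Array[String]): Unit = {
        val tracker = new Tracker(currentMergeId = 2)
        tracker.register(MergeStatus("hostA", 1, 1000L)) // stale attempt, ignored
        tracker.register(MergeStatus("hostA", 2, 1000L)) // current attempt, kept
        println(tracker.available) // 1
      }
    }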
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index cc09e02..312d1f8 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -315,7 +315,8 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
         shuffleMapStage: ShuffleMapStage): Unit = {
       if (shuffleMergeRegister) {
         for (part <- 0 until shuffleMapStage.shuffleDep.partitioner.numPartitions) {
-          val mergeStatuses = Seq((part, makeMergeStatus("")))
+          val mergeStatuses = Seq((part, makeMergeStatus("",
+            shuffleMapStage.shuffleDep.shuffleMergeId)))
           handleRegisterMergeStatuses(shuffleMapStage, mergeStatuses)
         }
         if (shuffleMergeFinalize) {
@@ -3726,9 +3727,11 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
         (Success, makeMapStatus("hostA", parts))
     }.toSeq)
     val shuffleMapStage = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
-    scheduler.handleRegisterMergeStatuses(shuffleMapStage, Seq((0, makeMergeStatus("hostA"))))
+    scheduler.handleRegisterMergeStatuses(shuffleMapStage, Seq((0, makeMergeStatus("hostA",
+      shuffleDep.shuffleMergeId))))
     scheduler.handleShuffleMergeFinalized(shuffleMapStage)
-    scheduler.handleRegisterMergeStatuses(shuffleMapStage, Seq((1, makeMergeStatus("hostA"))))
+    scheduler.handleRegisterMergeStatuses(shuffleMapStage, Seq((1, makeMergeStatus("hostA",
+      shuffleDep.shuffleMergeId))))
     assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId) == 1)
   }
 
@@ -3779,6 +3782,71 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     assert(mapOutputTracker.getNumAvailableOutputs(shuffleDep.shuffleId) == parts)
   }
 
+  test("SPARK-32923: handle stage failure for indeterminate map stage with push-based shuffle") {
+    initPushBasedShuffleConfs(conf)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
+    val (shuffleId1, shuffleId2) = constructIndeterminateStageFetchFailed()
+
+    // Check status for all failedStages
+    val failedStages = scheduler.failedStages.toSeq
+    assert(failedStages.map(_.id) == Seq(1, 2))
+    // Shuffle blocks on "hostC" are lost, so the first task of `shuffleMapRdd2` needs to retry.
+    assert(failedStages.collect {
+      case stage: ShuffleMapStage if stage.shuffleDep.shuffleId == shuffleId2 => stage
+    }.head.findMissingPartitions() == Seq(0))
+    // The result stage is still waiting for its 2 tasks to complete
+    assert(failedStages.collect {
+      case stage: ResultStage => stage
+    }.head.findMissingPartitions() == Seq(0, 1))
+    // shuffleMergeId for indeterminate stages starts from 1
+    assert(failedStages.collect {
+      case stage: ShuffleMapStage => stage.shuffleDep.shuffleMergeId
+    }.forall(x => x == 1))
+    scheduler.resubmitFailedStages()
+
+    // The first task of the `shuffleMapRdd2` failed with fetch failure
+    runEvent(makeCompletionEvent(
+      taskSets(3).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0L, 0, 0, "ignored"),
+      null))
+
+    val newFailedStages = scheduler.failedStages.toSeq
+    assert(newFailedStages.map(_.id) == Seq(0, 1))
+    // shuffleMergeId for indeterminate failed stages should be 2
+    assert(failedStages.collect {
+      case stage: ShuffleMapStage => stage.shuffleDep.shuffleMergeId
+    }.forall(x => x == 2))
+    scheduler.resubmitFailedStages()
+
+    // The first shuffle map stage is resubmitted and reruns all tasks.
+    assert(taskSets(4).stageId == 0)
+    assert(taskSets(4).stageAttemptId == 1)
+    assert(taskSets(4).tasks.length == 2)
+
+    // Finish all stages.
+    completeShuffleMapStageSuccessfully(0, 1, 2)
+    assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty))
+    // shuffleMergeId should be 2 for attempt number 1 of stage 0
+    assert(mapOutputTracker.shuffleStatuses.get(shuffleId1).forall(
+      _.mergeStatuses.forall(x => x.shuffleMergeId == 2)))
+    assert(mapOutputTracker.getNumAvailableMergeResults(shuffleId1) == 2)
+
+    completeShuffleMapStageSuccessfully(1, 2, 2, Seq("hostC", "hostD"))
+    assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty))
+    // shuffleMergeId should be 3 for attempt number 2 of stage 1
+    assert(mapOutputTracker.shuffleStatuses.get(shuffleId2).forall(
+      _.mergeStatuses.forall(x => x.shuffleMergeId == 3)))
+    assert(mapOutputTracker.getNumAvailableMergeResults(shuffleId2) == 2)
+
+    complete(taskSets(6), Seq((Success, 11), (Success, 12)))
+
+    // Job successfully ended.
+    assert(results === Map(0 -> 11, 1 -> 12))
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
   /**
    * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
    * Note that this checks only the host and not the executor ID.
@@ -3843,8 +3911,8 @@ object DAGSchedulerSuite {
     BlockManagerId(host + "-exec", host, 12345)
   }
 
-  def makeMergeStatus(host: String, size: Long = 1000): MergeStatus =
-    MergeStatus(makeBlockManagerId(host), mock(classOf[RoaringBitmap]), size)
+  def makeMergeStatus(host: String, shuffleMergeId: Int, size: Long = 1000): MergeStatus =
+    MergeStatus(makeBlockManagerId(host), shuffleMergeId, mock(classOf[RoaringBitmap]), size)
 
   def addMergerLocs(locs: Seq[String]): Unit = {
     locs.foreach { loc => mergerLocs.append(makeBlockManagerId(loc)) }
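
The new test encodes the progression that matters here: an indeterminate stage's shuffleMergeId
starts at 1 and moves up by one on every retry, so merged output from different attempts can
never be confused. A minimal sketch of that rule (ShuffleDep and newShuffleMergeState below are
illustrative names, not Spark's ShuffleDependency API; deterministic stages keeping 0 matches
the value the other suites use):

    object MergeIdProgressionSketch {
      final class ShuffleDep(val isIndeterminate: Boolean) {
        private var mergeId = 0
        def shuffleMergeId: Int = mergeId
        def newShuffleMergeState(): Unit =
          if (isIndeterminate) mergeId += 1 // deterministic stages stay at 0
      }

      def main(args: Array[String]): Unit = {
        val dep = new ShuffleDep(isIndeterminate = true)
        dep.newShuffleMergeState() // first attempt -> 1
        dep.newShuffleMergeState() // retry         -> 2
        println(dep.shuffleMergeId) // 2, matching the stage 0 assertions above
      }
    }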
diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala
index 2800be1..26cdad8 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala
@@ -96,7 +96,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     val blockPusher = new TestShuffleBlockPusher(conf)
     val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port))
     val largeBlockSize = 2 * 1024 * 1024
-    val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0,
+    val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, 0,
       mock(classOf[File]), Array(2, 2, 2, largeBlockSize, largeBlockSize), mergerLocs,
       mock(classOf[TransportConf]))
     assert(pushRequests.length == 3)
@@ -107,7 +107,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     conf.set("spark.shuffle.push.maxBlockSizeToPush", "1k")
     val blockPusher = new TestShuffleBlockPusher(conf)
     val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port))
-    val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0,
+    val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, 0,
       mock(classOf[File]), Array(2, 2, 2, 1028, 1024), mergerLocs, mock(classOf[TransportConf]))
     assert(pushRequests.length == 2)
     verifyPushRequests(pushRequests, Seq(6, 1024))
@@ -117,7 +117,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1")
     val blockPusher = new TestShuffleBlockPusher(conf)
     val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port))
-    val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0,
+    val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, 0,
       mock(classOf[File]), Array(2, 2, 2, 2, 2), mergerLocs, mock(classOf[TransportConf]))
     assert(pushRequests.length == 5)
     verifyPushRequests(pushRequests, Seq(2, 2, 2, 2, 2))
@@ -220,7 +220,8 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     val errorHandler = pusher.createErrorHandler()
     assert(
       !errorHandler.shouldRetryError(new RuntimeException(
-        new IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX))))
+        new IllegalArgumentException(
+          BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX))))
     assert(errorHandler.shouldRetryError(new RuntimeException(new ConnectException())))
     assert(
       errorHandler.shouldRetryError(new RuntimeException(new IllegalArgumentException(
@@ -233,7 +234,8 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     val errorHandler = pusher.createErrorHandler()
     assert(
       !errorHandler.shouldLogError(new RuntimeException(
-        new IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX))))
+        new IllegalArgumentException(
+          BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX))))
     assert(!errorHandler.shouldLogError(new RuntimeException(
       new IllegalArgumentException(
         BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX))))
@@ -284,7 +286,8 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
             failBlock = false
             // Fail the first block with the too late exception.
             blockPushListener.onBlockPushFailure(blockId, new RuntimeException(
-              new IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX)))
+              new IllegalArgumentException(
+                BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX)))
           } else {
             pushedBlocks += blockId
             blockPushListener.onBlockPushSuccess(blockId, mock(classOf[ManagedBuffer]))
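
The error-handler tests above pin down one behaviour: a push that arrives too late or targets a
stale merge attempt is neither retried nor logged, while transient failures such as connection
errors remain retriable. An illustrative sketch of that check (not Spark's BlockPushErrorHandler;
the marker string below is a stand-in for the real message suffix constant):

    object PushErrorSketch {
      private val TooLateOrStale = "too late or stale block push" // hypothetical marker text

      private def mentionsStalePush(t: Throwable): Boolean =
        Option(t.getCause).flatMap(c => Option(c.getMessage))
          .exists(_.contains(TooLateOrStale))

      def shouldRetryError(t: Throwable): Boolean = !mentionsStalePush(t)
      def shouldLogError(t: Throwable): Boolean = !mentionsStalePush(t)

      def main(args: Array[String]): Unit = {
        val stale = new RuntimeException(
          new IllegalArgumentException(s"Block push rejected: $TooLateOrStale"))
        println(shouldRetryError(stale)) // false: never retry a stale push
        println(shouldRetryError(new RuntimeException(new java.net.ConnectException())))
        // true: transient network failures remain retriable
      }
    }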
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
index 49c079c..abe2b56 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
@@ -27,7 +27,7 @@ import org.mockito.invocation.InvocationOnMock
 import org.roaringbitmap.RoaringBitmap
 import org.scalatest.BeforeAndAfterEach
 
-import org.apache.spark.{MapOutputTracker, SparkConf, SparkFunSuite}
+import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.internal.config
 import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleBlockInfo}
 import org.apache.spark.storage._
@@ -172,8 +172,9 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
 
   test("getMergedBlockData should return expected FileSegmentManagedBuffer list") {
     val shuffleId = 1
+    val shuffleMergeId = 0
     val reduceId = 1
-    val dataFileName = s"shuffleMerged_${appId}_${shuffleId}_$reduceId.data"
+    val dataFileName = s"shuffleMerged_${appId}_${shuffleId}_${shuffleMergeId}_$reduceId.data"
     val dataFile = new File(tempDir.getAbsolutePath, dataFileName)
     val out = new FileOutputStream(dataFile)
     Utils.tryWithSafeFinally {
@@ -181,12 +182,13 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
     } {
       out.close()
     }
-    val indexFileName = s"shuffleMerged_${appId}_${shuffleId}_$reduceId.index"
+    val indexFileName = s"shuffleMerged_${appId}_${shuffleId}_${shuffleMergeId}_$reduceId.index"
     generateMergedShuffleIndexFile(indexFileName)
     val resolver = new IndexShuffleBlockResolver(conf, blockManager)
     val dirs = Some(Array[String](tempDir.getAbsolutePath))
     val managedBufferList =
-      resolver.getMergedBlockData(ShuffleBlockId(shuffleId, -1, reduceId), dirs)
+      resolver.getMergedBlockData(ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId),
+        dirs)
     assert(managedBufferList.size === 3)
     assert(managedBufferList(0).size === 10)
     assert(managedBufferList(1).size === 0)
@@ -195,8 +197,9 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
 
   test("getMergedBlockMeta should return expected MergedBlockMeta") {
     val shuffleId = 1
+    val shuffleMergeId = 0
     val reduceId = 1
-    val metaFileName = s"shuffleMerged_${appId}_${shuffleId}_$reduceId.meta"
+    val metaFileName = s"shuffleMerged_${appId}_${shuffleId}_${shuffleMergeId}_$reduceId.meta"
     val metaFile = new File(tempDir.getAbsolutePath, metaFileName)
     val chunkTracker = new RoaringBitmap()
     val metaFileOutputStream = new FileOutputStream(metaFile)
@@ -216,13 +219,14 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
     }{
       outMeta.close()
     }
-    val indexFileName = s"shuffleMerged_${appId}_${shuffleId}_$reduceId.index"
+    val indexFileName = s"shuffleMerged_${appId}_${shuffleId}_${shuffleMergeId}_$reduceId.index"
     generateMergedShuffleIndexFile(indexFileName)
     val resolver = new IndexShuffleBlockResolver(conf, blockManager)
     val dirs = Some(Array[String](tempDir.getAbsolutePath))
     val mergedBlockMeta =
       resolver.getMergedBlockMeta(
-        ShuffleBlockId(shuffleId, MapOutputTracker.SHUFFLE_PUSH_MAP_ID, reduceId), dirs)
+        ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId),
+        dirs)
     assert(mergedBlockMeta.getNumChunks === 3)
     assert(mergedBlockMeta.readChunkBitmaps().size === 3)
     assert(mergedBlockMeta.readChunkBitmaps()(0).contains(1))
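
The file names the resolver tests build by hand follow one scheme, now with shuffleMergeId
between the shuffle id and the reduce id. A small sketch of that naming (the helper and the
placeholder app id are illustrative; the real names come from the merged block ids):

    object MergedFileNameSketch {
      def mergedFileName(appId: String, shuffleId: Int, shuffleMergeId: Int,
          reduceId: Int, ext: String): String =
        s"shuffleMerged_${appId}_${shuffleId}_${shuffleMergeId}_$reduceId.$ext"

      def main(args: Array[String]): Unit = {
        // Same shape as dataFileName in the test above, e.g. shuffleMerged_app_000_1_0_1.data
        println(mergedFileName("app_000", 1, 0, 1, "data"))
        println(mergedFileName("app_000", 1, 0, 1, "index"))
        println(mergedFileName("app_000", 1, 0, 1, "meta"))
      }
    }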
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
index e8c3c2d..2fb8fa4 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
@@ -105,37 +105,52 @@ class BlockIdSuite extends SparkFunSuite {
   }
 
   test("shuffle merged data") {
-    val id = ShuffleMergedDataBlockId("app_000", 8, 9)
-    assertSame(id, ShuffleMergedDataBlockId("app_000", 8, 9))
-    assertDifferent(id, ShuffleMergedDataBlockId("app_000", 9, 9))
-    assert(id.name === "shuffleMerged_app_000_8_9.data")
+    val id = ShuffleMergedDataBlockId("app_000", 8, 0, 9)
+    assertSame(id, ShuffleMergedDataBlockId("app_000", 8, 0, 9))
+    assertDifferent(id, ShuffleMergedDataBlockId("app_000", 9, 0, 9))
+    assert(id.name === "shuffleMerged_app_000_8_0_9.data")
     assert(id.asRDDId === None)
     assert(id.appId === "app_000")
+    assert(id.shuffleMergeId == 0)
     assert(id.shuffleId === 8)
     assert(id.reduceId === 9)
     assertSame(id, BlockId(id.toString))
   }
 
   test("shuffle merged index") {
-    val id = ShuffleMergedIndexBlockId("app_000", 8, 9)
-    assertSame(id, ShuffleMergedIndexBlockId("app_000", 8, 9))
-    assertDifferent(id, ShuffleMergedIndexBlockId("app_000", 9, 9))
-    assert(id.name === "shuffleMerged_app_000_8_9.index")
+    val id = ShuffleMergedIndexBlockId("app_000", 8, 0, 9)
+    assertSame(id, ShuffleMergedIndexBlockId("app_000", 8, 0, 9))
+    assertDifferent(id, ShuffleMergedIndexBlockId("app_000", 9, 0, 9))
+    assert(id.name === "shuffleMerged_app_000_8_0_9.index")
     assert(id.asRDDId === None)
     assert(id.appId === "app_000")
     assert(id.shuffleId === 8)
+    assert(id.shuffleMergeId == 0)
     assert(id.reduceId === 9)
     assertSame(id, BlockId(id.toString))
   }
 
   test("shuffle merged meta") {
-    val id = ShuffleMergedMetaBlockId("app_000", 8, 9)
-    assertSame(id, ShuffleMergedMetaBlockId("app_000", 8, 9))
-    assertDifferent(id, ShuffleMergedMetaBlockId("app_000", 9, 9))
-    assert(id.name === "shuffleMerged_app_000_8_9.meta")
+    val id = ShuffleMergedMetaBlockId("app_000", 8, 0, 9)
+    assertSame(id, ShuffleMergedMetaBlockId("app_000", 8, 0, 9))
+    assertDifferent(id, ShuffleMergedMetaBlockId("app_000", 9, 0, 9))
+    assert(id.name === "shuffleMerged_app_000_8_0_9.meta")
     assert(id.asRDDId === None)
     assert(id.appId === "app_000")
     assert(id.shuffleId === 8)
+    assert(id.shuffleMergeId == 0)
+    assert(id.reduceId === 9)
+    assertSame(id, BlockId(id.toString))
+  }
+
+  test("shuffle merged block") {
+    val id = ShuffleMergedBlockId(8, 0, 9)
+    assertSame(id, ShuffleMergedBlockId(8, 0, 9))
+    assertDifferent(id, ShuffleMergedBlockId(8, 1, 9))
+    assert(id.name === "shuffleMerged_8_0_9")
+    assert(id.asRDDId === None)
+    assert(id.shuffleId === 8)
+    assert(id.shuffleMergeId == 0)
     assert(id.reduceId === 9)
     assertSame(id, BlockId(id.toString))
   }
@@ -224,10 +239,10 @@ class BlockIdSuite extends SparkFunSuite {
   }
 
   test("shuffle chunk") {
-    val id = ShuffleBlockChunkId(1, 1, 0)
-    assertSame(id, ShuffleBlockChunkId(1, 1, 0))
-    assertDifferent(id, ShuffleBlockChunkId(1, 1, 1))
-    assert(id.name === "shuffleChunk_1_1_0")
+    val id = ShuffleBlockChunkId(1, 0, 1, 0)
+    assertSame(id, ShuffleBlockChunkId(1, 0, 1, 0))
+    assertDifferent(id, ShuffleBlockChunkId(1, 0, 1, 1))
+    assert(id.name === "shuffleChunk_1_0_1_0")
     assert(id.asRDDId === None)
     assert(id.shuffleId === 1)
     assert(id.reduceId === 1)
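
The chunk id test above fixes the name format for shuffle chunks:
shuffleChunk_<shuffleId>_<shuffleMergeId>_<reduceId>_<chunkId>. A simplified stand-in (not the
real BlockId hierarchy) that reproduces it:

    object ChunkIdSketch {
      case class ChunkId(shuffleId: Int, shuffleMergeId: Int, reduceId: Int, chunkId: Int) {
        def name: String =
          s"shuffleChunk_${shuffleId}_${shuffleMergeId}_${reduceId}_$chunkId"
      }

      def main(args: Array[String]): Unit =
        println(ChunkId(1, 0, 1, 0).name) // shuffleChunk_1_0_1_0
    }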
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index a5143cd..c22e1d0 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -1059,15 +1059,16 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
   test("SPARK-32922: fetch remote push-merged block meta") {
     val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
       (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "push-merged-host", 1),
-        toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)),
+        toBlockList(Seq(ShuffleMergedBlockId(0, 0, 2)), 2L,
+          SHUFFLE_PUSH_MAP_ID)),
       (BlockManagerId("remote-client-1", "remote-host-1", 1),
         toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1))
     )
     val blockChunks = Map[BlockId, ManagedBuffer](
       ShuffleBlockId(0, 0, 2) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(),
-      ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(),
-      ShuffleBlockChunkId(0, 2, 1) -> createMockManagedBuffer()
+      ShuffleBlockChunkId(0, 0, 2, 0) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 0, 2, 1) -> createMockManagedBuffer()
     )
     val blocksSem = new Semaphore(0)
     configureMockTransferForPushShuffle(blocksSem, blockChunks)
@@ -1078,17 +1079,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     when(pushMergedBlockMeta.getChunksBitmapBuffer).thenReturn(mock(classOf[ManagedBuffer]))
     val roaringBitmaps = Array(new RoaringBitmap, new RoaringBitmap)
     when(pushMergedBlockMeta.readChunkBitmaps()).thenReturn(roaringBitmaps)
-    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any(), any()))
       .thenAnswer((invocation: InvocationOnMock) => {
-        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val metaListener = invocation.getArguments()(5).asInstanceOf[MergedBlocksMetaListener]
         Future {
           val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
-          val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+          val shuffleMergeId = invocation.getArguments()(3).asInstanceOf[Int]
+          val reduceId = invocation.getArguments()(4).asInstanceOf[Int]
           logInfo(s"acquiring semaphore for host = ${invocation.getArguments()(0)}, " +
             s"port = ${invocation.getArguments()(1)}, " +
-            s"shuffleId = $shuffleId, reduceId = $reduceId")
+            s"shuffleId = $shuffleId, shuffleMergeId = $shuffleMergeId, reduceId = $reduceId")
           metaSem.acquire()
-          metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta)
+          metaListener.onSuccess(shuffleId, shuffleMergeId, reduceId, pushMergedBlockMeta)
         }
       })
     val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress)
@@ -1101,10 +1103,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     metaSem.release()
     val (id3, _) = iterator.next()
     blocksSem.acquire()
-    assert(id3 === ShuffleBlockChunkId(0, 2, 0))
+    assert(id3 === ShuffleBlockChunkId(0, 0, 2, 0))
     val (id4, _) = iterator.next()
     blocksSem.acquire()
-    assert(id4 === ShuffleBlockChunkId(0, 2, 1))
+    assert(id4 === ShuffleBlockChunkId(0, 0, 2, 1))
     assert(!iterator.hasNext)
   }
 
@@ -1113,7 +1115,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val remoteBmId = BlockManagerId("remote-client", "remote-host-1", 1)
     val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
       (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "push-merged-host", 1),
-        toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)),
+        toBlockList(Seq(ShuffleMergedBlockId(0, 0, 2)), 2L,
+          SHUFFLE_PUSH_MAP_ID)),
       (remoteBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)))
 
     val blockChunks = Map[BlockId, ManagedBuffer](
@@ -1127,13 +1130,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
         Seq(ShuffleBlockId(0, 1, 2), ShuffleBlockId(0, 2, 2)), 1L, 1))).iterator)
     val blocksSem = new Semaphore(0)
     configureMockTransferForPushShuffle(blocksSem, blockChunks)
-    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any(), any()))
       .thenAnswer((invocation: InvocationOnMock) => {
-        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val metaListener = invocation.getArguments()(5).asInstanceOf[MergedBlocksMetaListener]
         val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
-        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        val shuffleMergeId = invocation.getArguments()(3).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(4).asInstanceOf[Int]
         Future {
-          metaListener.onFailure(shuffleId, reduceId, new RuntimeException("forced error"))
+          metaListener.onFailure(shuffleId, shuffleMergeId, reduceId,
+            new RuntimeException("forced error"))
         }
       })
     val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress)
@@ -1154,7 +1159,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val remoteBmId = BlockManagerId("remote-client", "remote-host-1", 1)
     val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
       (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "push-merged-host", 1),
-        toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
+        toBlockList(Seq(ShuffleMergedBlockId(0, 0, 2)), 2L,
+          SHUFFLE_PUSH_MAP_ID)))
 
     val blockChunks = Map[BlockId, ManagedBuffer](
       ShuffleBlockId(0, 0, 2) -> createMockManagedBuffer(),
@@ -1165,13 +1171,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
         Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 1, 2)), 1L, 1))).iterator)
     val blocksSem = new Semaphore(0)
     configureMockTransferForPushShuffle(blocksSem, blockChunks)
-    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any(), any()))
       .thenAnswer((invocation: InvocationOnMock) => {
-        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val metaListener = invocation.getArguments()(5).asInstanceOf[MergedBlocksMetaListener]
         val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
-        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        val shuffleMergeId = invocation.getArguments()(3).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(4).asInstanceOf[Int]
         Future {
-          metaListener.onFailure(shuffleId, reduceId, new RuntimeException("forced error"))
+          metaListener.onFailure(shuffleId, shuffleMergeId, reduceId,
+            new RuntimeException("forced error"))
         }
       })
     val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress)
@@ -1225,7 +1233,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
 
     val dirsForMergedData = localDirsMap(SHUFFLE_MERGER_IDENTIFIER)
     doReturn(Seq(createMockManagedBuffer(2))).when(blockManager)
-      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), dirsForMergedData)
+      .getLocalMergedBlockData(ShuffleMergedBlockId(0, 0, 2), dirsForMergedData)
 
     // Get a valid chunk meta for this test
     val bitmaps = Array(new RoaringBitmap)
@@ -1236,7 +1244,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     } else {
       createMockPushMergedBlockMeta(bitmaps.length, bitmaps)
     }
-    when(blockManager.getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2),
+    when(blockManager.getLocalMergedBlockMeta(ShuffleMergedBlockId(0, 0, 2),
       dirsForMergedData)).thenReturn(pushMergedBlockMeta)
     when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn(
       Seq((localBmId,
@@ -1248,7 +1256,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
       (localBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)),
       (pushMergedBmId, toBlockList(
-        Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
+        Seq(ShuffleMergedBlockId(0, 0, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
   }
 
   private def verifyLocalBlocksFromFallback(iterator: ShuffleBlockFetcherIterator): Unit = {
@@ -1270,7 +1278,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val blocksByAddress = prepareForFallbackToLocalBlocks(
       blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs))
     doThrow(new RuntimeException("Forced error")).when(blockManager)
-      .getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
+      .getLocalMergedBlockMeta(ShuffleMergedBlockId(0, 0, 2), localDirs)
     val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
       blockManager = Some(blockManager))
     verifyLocalBlocksFromFallback(iterator)
@@ -1295,7 +1303,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val blocksByAddress = prepareForFallbackToLocalBlocks(
       blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs))
     doThrow(new RuntimeException("Forced error")).when(blockManager)
-      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
+      .getLocalMergedBlockData(ShuffleMergedBlockId(0, 0, 2), localDirs)
     val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
       blockManager = Some(blockManager))
     verifyLocalBlocksFromFallback(iterator)
@@ -1312,18 +1320,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val pushMergedBmId = BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, localHost, 1)
     val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
       (localBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)),
-      (pushMergedBmId, toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2),
-        ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 3)), 2L, SHUFFLE_PUSH_MAP_ID)))
+      (pushMergedBmId, toBlockList(Seq(ShuffleMergedBlockId(0, 0, 2),
+        ShuffleMergedBlockId(0, 0, 3)), 2L, SHUFFLE_PUSH_MAP_ID)))
     doThrow(new RuntimeException("Forced error")).when(blockManager)
-      .getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
+      .getLocalMergedBlockMeta(ShuffleMergedBlockId(0, 0, 2), localDirs)
     // Create a valid chunk meta for partition 3
     val bitmaps = Array(new RoaringBitmap)
     bitmaps(0).add(1) // chunk 0 has mapId 1
     doReturn(createMockPushMergedBlockMeta(bitmaps.length, bitmaps)).when(blockManager)
-      .getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 3), localDirs)
+      .getLocalMergedBlockMeta(ShuffleMergedBlockId(0, 0, 3), localDirs)
     // Return valid buffer for chunk in partition 3
     doReturn(Seq(createMockManagedBuffer(2))).when(blockManager)
-      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 3), localDirs)
+      .getLocalMergedBlockData(ShuffleMergedBlockId(0, 0, 3), localDirs)
     val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
       blockManager = Some(blockManager))
     val (id1, _) = iterator.next()
@@ -1335,7 +1343,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val (id4, _) = iterator.next()
     assert(id4 === ShuffleBlockId(0, 2, 2))
     val (id5, _) = iterator.next()
-    assert(id5 === ShuffleBlockChunkId(0, 3, 0))
+    assert(id5 === ShuffleBlockChunkId(0, 0, 3, 0))
     assert(!iterator.hasNext)
   }
 
@@ -1358,7 +1366,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     // Since bitmaps are null, reading the push-merged block meta will fail, causing the
     // fallback to initiate.
     val pushMergedBlockMeta: MergedBlockMeta = createMockPushMergedBlockMeta(2, null)
-    when(blockManager.getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2),
+    when(blockManager.getLocalMergedBlockMeta(ShuffleMergedBlockId(0, 0, 2),
       dirsForMergedData)).thenReturn(pushMergedBlockMeta)
     when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn(
       Seq((localBmId,
@@ -1366,7 +1374,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
 
     val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
       (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "test-local-host", 1), toBlockList(
-        Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
+        Seq(ShuffleMergedBlockId(0, 0, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
     val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
       blockManager = Some(blockManager))
     // 1st instance of iterator.next() returns the original shuffle block (0, 0, 2)
@@ -1385,7 +1393,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val blocksByAddress = prepareForFallbackToLocalBlocks(blockManager, hostLocalDirs)
 
     doThrow(new RuntimeException("Forced error")).when(blockManager)
-      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), Array("local-dir"))
+      .getLocalMergedBlockData(ShuffleMergedBlockId(0, 0, 2), Array("local-dir"))
     // host local read for a shuffle block
     doReturn(createMockManagedBuffer()).when(blockManager)
       .getHostLocalShuffleData(ShuffleBlockId(0, 2, 2), Array("local-dir"))
@@ -1416,7 +1424,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       blockManager, hostLocalDirs) ++ hostLocalBlocks
 
     doThrow(new RuntimeException("Forced error")).when(blockManager)
-      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), Array("local-dir"))
+      .getLocalMergedBlockData(ShuffleMergedBlockId(0, 0, 2), Array("local-dir"))
     // host local read for this original shuffle block
     doReturn(createMockManagedBuffer()).when(blockManager)
       .getHostLocalShuffleData(ShuffleBlockId(0, 1, 2), Array("local-dir"))
@@ -1452,7 +1460,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     doReturn(Seq({
       new FileSegmentManagedBuffer(null, new File("non-existent"), 0, 100)
       })).when(blockManager).getLocalMergedBlockData(
-        ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
+      ShuffleMergedBlockId(0, 0, 2), localDirs)
     val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
       blockManager = Some(blockManager))
     verifyLocalBlocksFromFallback(iterator)
@@ -1466,7 +1474,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs))
     val corruptBuffer = createMockManagedBuffer(2)
     doReturn(Seq({corruptBuffer})).when(blockManager)
-      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
+      .getLocalMergedBlockData(ShuffleMergedBlockId(0, 0, 2), localDirs)
     val corruptStream = mock(classOf[InputStream])
     when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
     doReturn(corruptStream).when(corruptBuffer).createInputStream()
@@ -1477,7 +1485,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
 
   test("SPARK-32922: fallback to original blocks when failed to fetch remote shuffle chunk") {
     val blockChunks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 0, 2, 0) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 4, 2) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 5, 2) -> createMockManagedBuffer()
@@ -1489,13 +1497,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     bitmaps(1).add(4)
     bitmaps(1).add(5)
     val pushMergedBlockMeta = createMockPushMergedBlockMeta(2, bitmaps)
-    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any(), any()))
       .thenAnswer((invocation: InvocationOnMock) => {
-        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val metaListener = invocation.getArguments()(5).asInstanceOf[MergedBlocksMetaListener]
         val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
-        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        val shuffleMergeId = invocation.getArguments()(3).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(4).asInstanceOf[Int]
         Future {
-          metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta)
+          metaListener.onSuccess(shuffleId, shuffleMergeId, reduceId, pushMergedBlockMeta)
         }
       })
     val fallbackBlocksByAddr = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
@@ -1506,10 +1515,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       .thenReturn(fallbackBlocksByAddr.iterator)
     val iterator = createShuffleBlockIteratorWithDefaults(Map(
       BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "remote-client-1", 1) ->
-        toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 12L, SHUFFLE_PUSH_MAP_ID)))
+        toBlockList(Seq(ShuffleMergedBlockId(0, 0, 2)),
+          12L, SHUFFLE_PUSH_MAP_ID)))
     val (id1, _) = iterator.next()
     blocksSem.acquire(1)
-    assert(id1 === ShuffleBlockChunkId(0, 2, 0))
+    assert(id1 === ShuffleBlockChunkId(0, 0, 2, 0))
     val (id3, _) = iterator.next()
     blocksSem.acquire(3)
     assert(id3 === ShuffleBlockId(0, 3, 2))
@@ -1531,20 +1541,21 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val blocksSem = new Semaphore(0)
     configureMockTransferForPushShuffle(blocksSem, blockChunks)
     val pushMergedBlockMeta = createMockPushMergedBlockMeta(2, null)
-    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any(), any()))
       .thenAnswer((invocation: InvocationOnMock) => {
-        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val metaListener = invocation.getArguments()(5).asInstanceOf[MergedBlocksMetaListener]
         val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
-        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        val shuffleMergeId = invocation.getArguments()(3).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(4).asInstanceOf[Int]
         Future {
-          metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta)
+          metaListener.onSuccess(shuffleId, shuffleMergeId, reduceId, pushMergedBlockMeta)
         }
       })
     val remoteMergedBlockMgrId = BlockManagerId(
       SHUFFLE_MERGER_IDENTIFIER, "remote-host-2", 1)
     val iterator = createShuffleBlockIteratorWithDefaults(
       Map(remoteMergedBlockMgrId -> toBlockList(
-        Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
+        Seq(ShuffleMergedBlockId(0, 0, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
     val (id1, _) = iterator.next()
     blocksSem.acquire(2)
     assert(id1 === ShuffleBlockId(0, 0, 2))
@@ -1556,10 +1567,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
   test("SPARK-32922: failure to fetch a remote shuffle chunk initiates the fallback of " +
     "pending shuffle chunks immediately") {
     val blockChunks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 0, 2, 0) -> createMockManagedBuffer(),
       // ShuffleBlockChunkId(0, 0, 2, 1) will cause a failure as it is not in blockChunks.
-      ShuffleBlockChunkId(0, 2, 2) -> createMockManagedBuffer(),
-      ShuffleBlockChunkId(0, 2, 3) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 0, 2, 2) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 0, 2, 3) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 4, 2) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 5, 2) -> createMockManagedBuffer(),
@@ -1574,17 +1585,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     when(pushMergedBlockMeta.getChunksBitmapBuffer).thenReturn(mock(classOf[ManagedBuffer]))
     val roaringBitmaps = Array.fill[RoaringBitmap](4)(new RoaringBitmap)
     when(pushMergedBlockMeta.readChunkBitmaps()).thenReturn(roaringBitmaps)
-    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any(), any()))
       .thenAnswer((invocation: InvocationOnMock) => {
-        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val metaListener = invocation.getArguments()(5).asInstanceOf[MergedBlocksMetaListener]
         val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
-        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        val shuffleMergeId = invocation.getArguments()(3).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(4).asInstanceOf[Int]
         Future {
           logInfo(s"acquiring semaphore for host = ${invocation.getArguments()(0)}, " +
             s"port = ${invocation.getArguments()(1)}, " +
-            s"shuffleId = $shuffleId, reduceId = $reduceId")
+            s"shuffleId = $shuffleId, shuffleMergeId = $shuffleMergeId, reduceId = $reduceId")
           metaSem.release()
-          metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta)
+          metaListener.onSuccess(shuffleId, shuffleMergeId, reduceId, pushMergedBlockMeta)
         }
       })
     val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
@@ -1596,12 +1608,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
 
     val iterator = createShuffleBlockIteratorWithDefaults(Map(
       BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "test-client-1", 2) ->
-        toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 16L, SHUFFLE_PUSH_MAP_ID)),
+        toBlockList(Seq(ShuffleMergedBlockId(0, 0, 2)),
+          16L, SHUFFLE_PUSH_MAP_ID)),
       maxBytesInFlight = 4)
     metaSem.acquire(1)
     val (id1, _) = iterator.next()
     blocksSem.acquire(1)
-    assert(id1 === ShuffleBlockChunkId(0, 2, 0))
+    assert(id1 === ShuffleBlockChunkId(0, 0, 2, 0))
     val regularBlocks = new mutable.HashSet[BlockId]()
     val (id2, _) = iterator.next()
     blocksSem.acquire(1)
@@ -1623,12 +1636,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
   test("SPARK-32922: failure to fetch a remote shuffle chunk initiates the fallback of " +
     "pending shuffle chunks immediately which got deferred") {
     val blockChunks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(),
-      ShuffleBlockChunkId(0, 2, 1) -> createMockManagedBuffer(),
-      ShuffleBlockChunkId(0, 2, 2) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 0, 2, 0) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 0, 2, 1) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 0, 2, 2) -> createMockManagedBuffer(),
       // ShuffleBlockChunkId(0, 0, 2, 3) will cause a failure as it is not in blockChunks.
-      ShuffleBlockChunkId(0, 2, 4) -> createMockManagedBuffer(),
-      ShuffleBlockChunkId(0, 2, 5) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 0, 2, 4) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 0, 2, 5) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 4, 2) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 5, 2) -> createMockManagedBuffer(),
@@ -1642,17 +1655,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     when(pushMergedBlockMeta.getChunksBitmapBuffer).thenReturn(mock(classOf[ManagedBuffer]))
     val roaringBitmaps = Array.fill[RoaringBitmap](6)(new RoaringBitmap)
     when(pushMergedBlockMeta.readChunkBitmaps()).thenReturn(roaringBitmaps)
-    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any(), any()))
       .thenAnswer((invocation: InvocationOnMock) => {
-        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val metaListener = invocation.getArguments()(5).asInstanceOf[MergedBlocksMetaListener]
         val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
-        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        val shuffleMergeId = invocation.getArguments()(3).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(4).asInstanceOf[Int]
         Future {
           logInfo(s"acquiring semaphore for host = ${invocation.getArguments()(0)}, " +
             s"port = ${invocation.getArguments()(1)}, " +
-            s"shuffleId = $shuffleId, reduceId = $reduceId")
+            s"shuffleId = $shuffleId, shuffleMergeId = $shuffleMergeId, reduceId = $reduceId")
           metaSem.release()
-          metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta)
+          metaListener.onSuccess(shuffleId, shuffleMergeId, reduceId, pushMergedBlockMeta)
         }
       })
     val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
@@ -1664,17 +1678,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
 
     val iterator = createShuffleBlockIteratorWithDefaults(Map(
       BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "test-client-1", 2) ->
-        toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 24L, SHUFFLE_PUSH_MAP_ID)),
+        toBlockList(Seq(ShuffleMergedBlockId(0, 0, 2)), 24L,
+          SHUFFLE_PUSH_MAP_ID)),
       maxBytesInFlight = 8, maxBlocksInFlightPerAddress = 1)
     metaSem.acquire(1)
     val (id1, _) = iterator.next()
     blocksSem.acquire(2)
-    assert(id1 === ShuffleBlockChunkId(0, 2, 0))
+    assert(id1 === ShuffleBlockChunkId(0, 0, 2, 0))
     val (id2, _) = iterator.next()
-    assert(id2 === ShuffleBlockChunkId(0, 2, 1))
+    assert(id2 === ShuffleBlockChunkId(0, 0, 2, 1))
     val (id3, _) = iterator.next()
     blocksSem.acquire(1)
-    assert(id3 === ShuffleBlockChunkId(0, 2, 2))
+    assert(id3 === ShuffleBlockChunkId(0, 0, 2, 2))
     val regularBlocks = new mutable.HashSet[BlockId]()
     val (id4, _) = iterator.next()
     blocksSem.acquire(1)
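
Across these iterator tests, the mocked meta client gained one argument: getMergedBlockMeta now
takes (host, port, shuffleId, shuffleMergeId, reduceId, listener), and both listener callbacks
echo the shuffleMergeId back so the fetcher knows which merge attempt a response belongs to. A
self-contained sketch of that listener shape (simplified types, not the real network-shuffle
API):

    object MetaListenerSketch {
      trait MergedBlocksMetaListener {
        def onSuccess(shuffleId: Int, shuffleMergeId: Int, reduceId: Int, numChunks: Int): Unit
        def onFailure(shuffleId: Int, shuffleMergeId: Int, reduceId: Int, t: Throwable): Unit
      }

      // A fake client that always "finds" two chunks, mirroring the happy-path mocks above.
      def getMergedBlockMeta(host: String, port: Int, shuffleId: Int,
          shuffleMergeId: Int, reduceId: Int, listener: MergedBlocksMetaListener): Unit =
        listener.onSuccess(shuffleId, shuffleMergeId, reduceId, numChunks = 2)

      def main(args: Array[String]): Unit =
        getMergedBlockMeta("push-merged-host", 1, 0, 0, 2, new MergedBlocksMetaListener {
          def onSuccess(s: Int, m: Int, r: Int, n: Int): Unit =
            println(s"meta for shuffleMerged_${s}_${m}_$r: $n chunks")
          def onFailure(s: Int, m: Int, r: Int, t: Throwable): Unit =
            println(s"falling back to original blocks for shuffleMerged_${s}_${m}_$r")
        })
    }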
