You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@celeborn.apache.org by et...@apache.org on 2022/11/16 07:04:48 UTC

[incubator-celeborn] branch branch-0.1 updated: 1.[BUG] Fix fetch incorrect data chunk (#926) 2.[ISSUE-939][REFACTOR] Bump up ratis to 2.4.0 (#940) 3.[ISSUE-925][FOLLOWUP] Refactor class name of RetryingChunkReceiveCallback (#954) 4.Add instructions for the latest release policy.

This is an automated email from the ASF dual-hosted git repository.

ethanfeng pushed a commit to branch branch-0.1
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/branch-0.1 by this push:
     new c86faf98 1.[BUG] Fix fetch incorrect data chunk (#926) 2.[ISSUE-939][REFACTOR] Bump up ratis to 2.4.0 (#940) 3.[ISSUE-925][FOLLOWUP] Refactor class name of RetryingChunkReceiveCallback (#954) 4.Add instructions for the latest release policy.
c86faf98 is described below

commit c86faf986b07d6fcdd60a266458a5e3882a51e03
Author: Ethan Feng <et...@apache.org>
AuthorDate: Wed Nov 9 22:31:39 2022 +0800

    1.[BUG] Fix fetch incorrect data chunk (#926)
    2.[ISSUE-939][REFACTOR] Bump up ratis to 2.4.0 (#940)
    3.[ISSUE-925][FOLLOWUP] Refactor class name of RetryingChunkReceiveCallback (#954)
    4.Add instructions for the latest release policy.
---
 README.md                                          |   4 +-
 .../{RetryingChunkClient.java => ChunkClient.java} | 192 ++++------
 .../emr/rss/client/read/PartitionReader.java       |  30 ++
 .../com/aliyun/emr/rss/client/read/Replica.java    |  94 +++++
 .../aliyun/emr/rss/client/read/RssInputStream.java | 102 +-----
 .../emr/rss/client/read/WorkerPartitionReader.java | 188 ++++++++++
 .../rss/client/read/RetryingChunkClientSuiteJ.java | 394 ---------------------
 .../network/client/ChunkReceivedCallback.java      |   5 +-
 .../rss/common/network/client/TransportClient.java |   2 +-
 .../network/client/TransportResponseHandler.java   |   9 +-
 .../scala/com/aliyun/emr/rss/common/RssConf.scala  |   4 +
 .../network/ChunkFetchIntegrationSuiteJ.java       |   5 +-
 .../network/RequestTimeoutIntegrationSuiteJ.java   |   5 +-
 .../network/TransportResponseHandlerSuiteJ.java    |  10 +-
 pom.xml                                            |   2 +-
 .../clustermeta/ha/MasterStateMachineSuiteJ.java   |   5 +-
 .../service/deploy/worker/FileWriterSuiteJ.java    |   5 +-
 17 files changed, 417 insertions(+), 639 deletions(-)

diff --git a/README.md b/README.md
index 2badec14..4f2c5e69 100644
--- a/README.md
+++ b/README.md
@@ -42,7 +42,9 @@ RSS Worker's slot count is decided by `rss.worker.numSlots` or`rss.worker.flush.
 RSS worker's slot count decreases when a partition is allocated and increments when a partition is freed.  
 
 ## Build
-RSS supports Spark2.x(>=2.4.0), Spark3.x(>=3.0.1) and only tested under Java8(JDK1.8).
+RSS supports Spark2.x(>=2.4.0), Spark3.x(>=3.0.1) and only tested under Java8(JDK1.8). 
+There won't be new release packages for branch-0.1, so if you need updates and fixes,
+you'll need to build your own package. 
 
 Build for Spark 2    
 `
diff --git a/client/src/main/java/com/aliyun/emr/rss/client/read/RetryingChunkClient.java b/client/src/main/java/com/aliyun/emr/rss/client/read/ChunkClient.java
similarity index 58%
rename from client/src/main/java/com/aliyun/emr/rss/client/read/RetryingChunkClient.java
rename to client/src/main/java/com/aliyun/emr/rss/client/read/ChunkClient.java
index 13672d9b..7ef96315 100644
--- a/client/src/main/java/com/aliyun/emr/rss/client/read/RetryingChunkClient.java
+++ b/client/src/main/java/com/aliyun/emr/rss/client/read/ChunkClient.java
@@ -18,9 +18,6 @@
 package com.aliyun.emr.rss.client.read;
 
 import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.List;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
@@ -36,9 +33,6 @@ import com.aliyun.emr.rss.common.network.buffer.ManagedBuffer;
 import com.aliyun.emr.rss.common.network.client.ChunkReceivedCallback;
 import com.aliyun.emr.rss.common.network.client.TransportClient;
 import com.aliyun.emr.rss.common.network.client.TransportClientFactory;
-import com.aliyun.emr.rss.common.network.protocol.Message;
-import com.aliyun.emr.rss.common.network.protocol.OpenStream;
-import com.aliyun.emr.rss.common.network.protocol.StreamHandle;
 import com.aliyun.emr.rss.common.network.util.NettyUtils;
 import com.aliyun.emr.rss.common.network.util.TransportConf;
 import com.aliyun.emr.rss.common.protocol.PartitionLocation;
@@ -56,19 +50,25 @@ import com.aliyun.emr.rss.common.util.Utils;
  * not be too many. Each retry is actually a switch between Master and Slave. Therefore, each retry
  * needs to create a new connection and reopen the file to generate the stream id.
  */
-public class RetryingChunkClient {
-  private static final Logger logger = LoggerFactory.getLogger(RetryingChunkClient.class);
+public class ChunkClient {
+  private static final Logger logger = LoggerFactory.getLogger(ChunkClient.class);
   private static final ExecutorService executorService = Executors.newCachedThreadPool(
       NettyUtils.createThreadFactory("Chunk Fetch Retry"));
 
   private final ChunkReceivedCallback callback;
-  private final List<Replica> replicas;
+  private final Replica replica;
   private final long retryWaitMs;
   private final int maxTries;
 
   private volatile int numTries = 0;
+  private PartitionLocation location;
+  private int fetchFailedChunkIndex;
 
-  public RetryingChunkClient(
+  public PartitionLocation getLocation() {
+    return location;
+  }
+
+  public ChunkClient(
       RssConf conf,
       String shuffleKey,
       PartitionLocation location,
@@ -77,7 +77,7 @@ public class RetryingChunkClient {
     this(conf, shuffleKey, location, callback, clientFactory, 0, Integer.MAX_VALUE);
   }
 
-  public RetryingChunkClient(
+  public ChunkClient(
       RssConf conf,
       String shuffleKey,
       PartitionLocation location,
@@ -86,26 +86,21 @@ public class RetryingChunkClient {
       int startMapIndex,
       int endMapIndex) {
     TransportConf transportConf = Utils.fromRssConf(conf, TransportModuleConstants.DATA_MODULE, 0);
+    this.fetchFailedChunkIndex = conf.testFetchFailedChunkIndex(conf);
 
-    this.replicas = new ArrayList<>(2);
     this.callback = callback;
     this.retryWaitMs = transportConf.ioRetryWaitTimeMs();
 
     long timeoutMs = RssConf.fetchChunkTimeoutMs(conf);
-    if (location != null) {
-      replicas.add(new Replica(timeoutMs, shuffleKey, location,
-        clientFactory, startMapIndex, endMapIndex));
-      if (location.getPeer() != null) {
-        replicas.add(new Replica(timeoutMs, shuffleKey, location.getPeer(),
-          clientFactory, startMapIndex, endMapIndex));
-      }
-    }
+    this.location = location;
 
-    if (this.replicas.size() <= 0) {
+    if (location == null) {
       throw new IllegalArgumentException("Must contain at least one available PartitionLocation.");
+    } else {
+      replica = new Replica(timeoutMs, shuffleKey, location,
+              clientFactory, startMapIndex, endMapIndex);
     }
-
-    this.maxTries = (transportConf.maxIORetries() + 1) * replicas.size();
+    this.maxTries = (transportConf.maxIORetries() + 1);
   }
 
   /**
@@ -114,32 +109,47 @@ public class RetryingChunkClient {
    *
    * @return numChunks.
    */
-  public int openChunks() throws IOException {
+  public synchronized int openChunks() throws IOException {
     int numChunks = -1;
+    Exception currentException = null;
     while (numChunks == -1 && hasRemainingRetries()) {
-      Replica replica = getCurrentReplica();
+      // Wait before every attempt except the first one.
+      if (numTries != 0) {
+        logger.info(
+                "Retrying openChunk ({}/{}) for chunk from {} after {} ms.",
+                numTries,
+                maxTries,
+                replica,
+                retryWaitMs);
+        Uninterruptibles.sleepUninterruptibly(retryWaitMs, TimeUnit.MILLISECONDS);
+      }
       try {
         replica.getOrOpenStream();
         numChunks = replica.getNumChunks();
+      } catch (InterruptedException e) {
+        Thread.currentThread().interrupt();
+        throw new IOException(e);
       } catch (Exception e) {
-        if (e instanceof InterruptedException) {
-          Thread.currentThread().interrupt();
-          throw new IOException(e);
-        }
-
-        logger.warn("Exception raised while sending open chunks message to {}.", replica, e);
-
-        numChunks = -1;
+        logger.error("Exception raised while sending open chunks message to " + replica + ".", e);
+        currentException = e;
         if (shouldRetry(e)) {
-          numTries += 1; // openChunks will not be concurrently called.
+          numTries += 1;
         } else {
           break;
         }
       }
     }
     if (numChunks == -1) {
-      throw new IOException(String.format("Could not open chunks after %d tries.", numTries));
+      if (currentException != null) {
+        callback.onFailure(0, location, new IOException(
+          String.format("Could not open chunks from %s after %d tries.", replica, numTries),
+            currentException));
+      } else {
+        callback.onFailure(0, location, new IOException(
+          String.format("Could not open chunks from %s after %d tries.", replica, numTries)));
+      }
     }
+    numTries = 0;
     return numChunks;
   }
 
@@ -151,33 +161,36 @@ public class RetryingChunkClient {
    * @param chunkIndex the index of the chunk to be fetched.
    */
   public void fetchChunk(int chunkIndex) {
-    Replica replica;
-    RetryingChunkReceiveCallback callback;
+    FetchChunkCallback callback;
     synchronized (this) {
-      replica = getCurrentReplica();
-      callback = new RetryingChunkReceiveCallback(numTries);
+      callback = new FetchChunkCallback(numTries);
+      if (fetchFailedChunkIndex != 0
+              && location.getPeer() != null
+              && chunkIndex == fetchFailedChunkIndex
+              && location.getMode() == PartitionLocation.Mode.Master) {
+        RuntimeException manualTriggeredFailure =
+                new RuntimeException("Manual triggered fetch failure");
+        callback.onFailure(chunkIndex, location, manualTriggeredFailure);
+      }
     }
     try {
       TransportClient client = replica.getOrOpenStream();
       client.fetchChunk(replica.getStreamId(), chunkIndex, callback);
     } catch (Exception e) {
-      logger.error("Exception raised while beginning fetch chunk {} {}.",
-          chunkIndex, numTries > 0 ? "(after " + numTries + " retries)" : "", e);
+      logger.error(
+              "Exception raised while beginning fetch chunk {}{}.",
+              chunkIndex,
+              numTries > 0 ? " (after " + numTries + " retries)" : "",
+              e);
 
       if (shouldRetry(e)) {
         initiateRetry(chunkIndex, callback.currentNumTries);
       } else {
-        callback.onFailure(chunkIndex, e);
+        callback.onFailure(chunkIndex, location, e);
       }
     }
   }
 
-  @VisibleForTesting
-  Replica getCurrentReplica() {
-    int currentReplicaIndex = numTries % replicas.size();
-    return replicas.get(currentReplicaIndex);
-  }
-
   @VisibleForTesting
   int getNumTries() {
     return numTries;
@@ -200,7 +213,7 @@ public class RetryingChunkClient {
     numTries = Math.max(numTries, currentNumTries + 1);
 
     logger.info("Retrying fetch ({}/{}) for chunk {} from {} after {} ms.",
-        currentNumTries, maxTries, chunkIndex, getCurrentReplica(), retryWaitMs);
+        currentNumTries, maxTries, chunkIndex, replica, retryWaitMs);
 
     executorService.submit(() -> {
       Uninterruptibles.sleepUninterruptibly(retryWaitMs, TimeUnit.MILLISECONDS);
@@ -208,94 +221,27 @@ public class RetryingChunkClient {
     });
   }
 
-  private class RetryingChunkReceiveCallback implements ChunkReceivedCallback {
+  private class FetchChunkCallback implements ChunkReceivedCallback {
     final int currentNumTries;
 
-    RetryingChunkReceiveCallback(int currentNumTries) {
+    FetchChunkCallback(int currentNumTries) {
       this.currentNumTries = currentNumTries;
     }
 
     @Override
-    public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
-      callback.onSuccess(chunkIndex, buffer);
+    public void onSuccess(int chunkIndex, ManagedBuffer buffer, PartitionLocation location) {
+      callback.onSuccess(chunkIndex, buffer, ChunkClient.this.location);
     }
 
     @Override
-    public void onFailure(int chunkIndex, Throwable e) {
+    public void onFailure(int chunkIndex, PartitionLocation location, Throwable e) {
       if (shouldRetry(e)) {
         initiateRetry(chunkIndex, this.currentNumTries);
       } else {
-        logger.error("Failed to fetch chunk {}, and will not retry({} tries).",
-          chunkIndex, this.currentNumTries);
-        callback.onFailure(chunkIndex, e);
+        logger.error("Abandon to fetch chunk {} after {} tries.", chunkIndex, this.currentNumTries);
+        callback.onFailure(chunkIndex, ChunkClient.this.location, e);
       }
     }
   }
 }
 
-class Replica {
-  private static final Logger logger = LoggerFactory.getLogger(Replica.class);
-  private final long timeoutMs;
-  private final String shuffleKey;
-  private final PartitionLocation location;
-  private final TransportClientFactory clientFactory;
-
-  private StreamHandle streamHandle;
-  private TransportClient client;
-  private int startMapIndex;
-  private int endMapIndex;
-
-  Replica(
-      long timeoutMs,
-      String shuffleKey,
-      PartitionLocation location,
-      TransportClientFactory clientFactory,
-      int startMapIndex,
-      int endMapIndex) {
-    this.timeoutMs = timeoutMs;
-    this.shuffleKey = shuffleKey;
-    this.location = location;
-    this.clientFactory = clientFactory;
-    this.startMapIndex = startMapIndex;
-    this.endMapIndex = endMapIndex;
-  }
-
-  Replica(
-      long timeoutMs,
-      String shuffleKey,
-      PartitionLocation location,
-      TransportClientFactory clientFactory) {
-    this(timeoutMs, shuffleKey, location, clientFactory, 0, Integer.MAX_VALUE);
-  }
-
-  public synchronized TransportClient getOrOpenStream()
-      throws IOException, InterruptedException {
-    if (client == null || !client.isActive()) {
-      client = clientFactory.createClient(location.getHost(), location.getFetchPort());
-
-      OpenStream openBlocks = new OpenStream(shuffleKey, location.getFileName(),
-        startMapIndex, endMapIndex);
-      ByteBuffer response = client.sendRpcSync(openBlocks.toByteBuffer(), timeoutMs);
-      streamHandle = (StreamHandle) Message.decode(response);
-    }
-    return client;
-  }
-
-  public long getStreamId() {
-    return streamHandle.streamId;
-  }
-
-  public int getNumChunks() {
-    return streamHandle.numChunks;
-  }
-
-  @Override
-  public String toString() {
-    return location.getHost() + ":" + location.getFetchPort();
-  }
-
-  @VisibleForTesting
-  PartitionLocation getLocation() {
-    return location;
-  }
-}
diff --git a/client/src/main/java/com/aliyun/emr/rss/client/read/PartitionReader.java b/client/src/main/java/com/aliyun/emr/rss/client/read/PartitionReader.java
new file mode 100644
index 00000000..6baef396
--- /dev/null
+++ b/client/src/main/java/com/aliyun/emr/rss/client/read/PartitionReader.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.aliyun.emr.rss.client.read;
+
+import java.io.IOException;
+
+import io.netty.buffer.ByteBuf;
+
+public interface PartitionReader {
+  boolean hasNext();
+
+  ByteBuf next() throws IOException;
+
+  void close();
+}
diff --git a/client/src/main/java/com/aliyun/emr/rss/client/read/Replica.java b/client/src/main/java/com/aliyun/emr/rss/client/read/Replica.java
new file mode 100644
index 00000000..60fb46ac
--- /dev/null
+++ b/client/src/main/java/com/aliyun/emr/rss/client/read/Replica.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.aliyun.emr.rss.client.read;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+import com.google.common.annotations.VisibleForTesting;
+
+import com.aliyun.emr.rss.common.network.client.TransportClient;
+import com.aliyun.emr.rss.common.network.client.TransportClientFactory;
+import com.aliyun.emr.rss.common.network.protocol.Message;
+import com.aliyun.emr.rss.common.network.protocol.OpenStream;
+import com.aliyun.emr.rss.common.network.protocol.StreamHandle;
+import com.aliyun.emr.rss.common.protocol.PartitionLocation;
+
+class Replica {
+  private final long timeoutMs;
+  private final String shuffleKey;
+  private final PartitionLocation location;
+  private final TransportClientFactory clientFactory;
+
+  private StreamHandle streamHandle;
+  private TransportClient client;
+  private final int startMapIndex;
+  private final int endMapIndex;
+
+  Replica(
+      long timeoutMs,
+      String shuffleKey,
+      PartitionLocation location,
+      TransportClientFactory clientFactory,
+      int startMapIndex,
+      int endMapIndex) {
+    this.timeoutMs = timeoutMs;
+    this.shuffleKey = shuffleKey;
+    this.location = location;
+    this.clientFactory = clientFactory;
+    this.startMapIndex = startMapIndex;
+    this.endMapIndex = endMapIndex;
+  }
+
+  public synchronized TransportClient getOrOpenStream() throws IOException, InterruptedException {
+    if (client == null || !client.isActive()) {
+      client = clientFactory.createClient(location.getHost(), location.getFetchPort());
+    }
+    if (streamHandle == null) {
+      OpenStream openBlocks =
+          new OpenStream(shuffleKey, location.getFileName(), startMapIndex, endMapIndex);
+      ByteBuffer response = client.sendRpcSync(openBlocks.toByteBuffer(), timeoutMs);
+      streamHandle = (StreamHandle) Message.decode(response);
+    }
+    return client;
+  }
+
+  public long getStreamId() {
+    return streamHandle.streamId;
+  }
+
+  public int getNumChunks() {
+    return streamHandle.numChunks;
+  }
+
+  @Override
+  public String toString() {
+    String shufflePartition =
+        String.format("%s:%d %s", location.getHost(), location.getFetchPort(), shuffleKey);
+    if (startMapIndex == 0 && endMapIndex == Integer.MAX_VALUE) {
+      return shufflePartition;
+    } else {
+      return String.format("%s[%d,%d)", shufflePartition, startMapIndex, endMapIndex);
+    }
+  }
+
+  @VisibleForTesting
+  PartitionLocation getLocation() {
+    return location;
+  }
+}
diff --git a/client/src/main/java/com/aliyun/emr/rss/client/read/RssInputStream.java b/client/src/main/java/com/aliyun/emr/rss/client/read/RssInputStream.java
index 822fb63f..1298657a 100644
--- a/client/src/main/java/com/aliyun/emr/rss/client/read/RssInputStream.java
+++ b/client/src/main/java/com/aliyun/emr/rss/client/read/RssInputStream.java
@@ -27,9 +27,6 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.concurrent.LinkedBlockingQueue;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicReference;
 import java.util.concurrent.atomic.LongAdder;
 
 import io.netty.buffer.ByteBuf;
@@ -39,9 +36,6 @@ import org.slf4j.LoggerFactory;
 
 import com.aliyun.emr.rss.client.compress.Decompressor;
 import com.aliyun.emr.rss.common.RssConf;
-import com.aliyun.emr.rss.common.network.buffer.ManagedBuffer;
-import com.aliyun.emr.rss.common.network.buffer.NettyManagedBuffer;
-import com.aliyun.emr.rss.common.network.client.ChunkReceivedCallback;
 import com.aliyun.emr.rss.common.network.client.TransportClientFactory;
 import com.aliyun.emr.rss.common.protocol.PartitionLocation;
 import com.aliyun.emr.rss.common.unsafe.Platform;
@@ -227,7 +221,8 @@ public abstract class RssInputStream extends InputStream {
         logger.debug("Read peer {} for attempt {}.", location, attemptNumber);
       }
 
-      return new PartitionReader(location);
+      return new WorkerPartitionReader(
+              conf, shuffleKey, location, clientFactory, startMapIndex, endMapIndex);
     }
 
     public void setCallback(MetricsCallback callback) {
@@ -376,98 +371,5 @@ public abstract class RssInputStream extends InputStream {
       return hasData;
     }
 
-    private final class PartitionReader {
-      private final RetryingChunkClient client;
-      private final int numChunks;
-
-      private int returnedChunks;
-      private int chunkIndex;
-
-      private final LinkedBlockingQueue<ByteBuf> results;
-      private final ChunkReceivedCallback callback;
-
-      private final AtomicReference<IOException> exception = new AtomicReference<>();
-
-      private boolean closed = false;
-
-      PartitionReader(PartitionLocation location) throws IOException {
-        results = new LinkedBlockingQueue<>();
-        callback = new ChunkReceivedCallback() {
-          @Override
-          public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
-            // only add the buffer to results queue if this reader is not closed.
-            synchronized(PartitionReader.this) {
-              ByteBuf buf = ((NettyManagedBuffer) buffer).getBuf();
-              if (!closed) {
-                buf.retain();
-                results.add(buf);
-              }
-            }
-          }
-
-          @Override
-          public void onFailure(int chunkIndex, Throwable e) {
-            String errorMsg = "Fetch chunk " + chunkIndex + " failed.";
-            logger.error(errorMsg, e);
-            exception.set(new IOException(errorMsg, e));
-          }
-        };
-        client = new RetryingChunkClient(conf, shuffleKey, location,
-          callback, clientFactory, startMapIndex, endMapIndex);
-        numChunks = client.openChunks();
-      }
-
-      boolean hasNext() {
-        return returnedChunks < numChunks;
-      }
-
-      ByteBuf next() throws IOException {
-        checkException();
-        if (chunkIndex < numChunks) {
-          fetchChunks();
-        }
-        ByteBuf chunk = null;
-        try {
-          while (chunk == null) {
-            checkException();
-            chunk = results.poll(500, TimeUnit.MILLISECONDS);
-          }
-        } catch (InterruptedException e) {
-          Thread.currentThread().interrupt();
-          IOException ioe = new IOException(e);
-          exception.set(ioe);
-          throw ioe;
-        }
-        returnedChunks++;
-        return chunk;
-      }
-
-      void close() {
-        synchronized(this) {
-          closed = true;
-        }
-        if (results.size() > 0) {
-          results.forEach(res -> res.release());
-        }
-        results.clear();
-      }
-
-      private void fetchChunks() {
-        final int inFlight = chunkIndex - returnedChunks;
-        if (inFlight < maxInFlight) {
-          final int toFetch = Math.min(maxInFlight - inFlight + 1, numChunks - chunkIndex);
-          for (int i = 0; i < toFetch; i++) {
-            client.fetchChunk(chunkIndex++);
-          }
-        }
-      }
-
-      private void checkException() throws IOException {
-        IOException e = exception.get();
-        if (e != null) {
-          throw e;
-        }
-      }
-    }
   }
 }
diff --git a/client/src/main/java/com/aliyun/emr/rss/client/read/WorkerPartitionReader.java b/client/src/main/java/com/aliyun/emr/rss/client/read/WorkerPartitionReader.java
new file mode 100644
index 00000000..5cc28520
--- /dev/null
+++ b/client/src/main/java/com/aliyun/emr/rss/client/read/WorkerPartitionReader.java
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.aliyun.emr.rss.client.read;
+
+import java.io.IOException;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+
+import io.netty.buffer.ByteBuf;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.aliyun.emr.rss.common.RssConf;
+import com.aliyun.emr.rss.common.network.buffer.ManagedBuffer;
+import com.aliyun.emr.rss.common.network.buffer.NettyManagedBuffer;
+import com.aliyun.emr.rss.common.network.client.ChunkReceivedCallback;
+import com.aliyun.emr.rss.common.network.client.TransportClientFactory;
+import com.aliyun.emr.rss.common.protocol.PartitionLocation;
+
+public class WorkerPartitionReader implements PartitionReader {
+  private final Logger logger = LoggerFactory.getLogger(WorkerPartitionReader.class);
+  private ChunkClient client;
+  private int numChunks;
+
+  private int returnedChunks;
+  private int currentChunkIndex;
+
+  private final LinkedBlockingQueue<ChunkData> results;
+
+  private final AtomicReference<IOException> exception = new AtomicReference<>();
+  private final int fetchMaxReqsInFlight;
+  private AtomicBoolean closed = new AtomicBoolean(false);
+  private Set<PartitionLocation> readableLocations = ConcurrentHashMap.newKeySet();
+  private Set<PartitionLocation> failedLocations = ConcurrentHashMap.newKeySet();
+
+  WorkerPartitionReader(
+      RssConf conf,
+      String shuffleKey,
+      PartitionLocation location,
+      TransportClientFactory clientFactory,
+      int startMapIndex,
+      int endMapIndex)
+      throws IOException {
+    fetchMaxReqsInFlight = conf.fetchChunkMaxReqsInFlight(conf);
+    results = new LinkedBlockingQueue<>();
+    readableLocations.add(location);
+    if (location.getPeer() != null) {
+      readableLocations.add(location.getPeer());
+    }
+    // only add the buffer to results queue if this reader is not closed.
+    ChunkReceivedCallback callback =
+      new ChunkReceivedCallback() {
+        @Override
+        public void onSuccess(int chunkIndex, ManagedBuffer buffer, PartitionLocation location) {
+          // only add the buffer to results queue if this reader is not closed.
+          ByteBuf buf = ((NettyManagedBuffer) buffer).getBuf();
+          if (!closed.get() && !failedLocations.contains(location)) {
+            buf.retain();
+            results.add(new ChunkData(buf, location));
+          }
+        }
+
+        @Override
+        public void onFailure(int chunkIndex, PartitionLocation location, Throwable e) {
+          readableLocations.remove(location);
+          if (readableLocations.isEmpty()) {
+            String errorMsg = "Fetch chunk " + chunkIndex + " failed.";
+            logger.error(errorMsg, e);
+            exception.set(new IOException(errorMsg, e));
+          } else {
+            try {
+              synchronized (WorkerPartitionReader.this) {
+                if (!failedLocations.contains(location)) {
+                  failedLocations.add(location);
+                  client = new ChunkClient(conf, shuffleKey, location.getPeer(),
+                    this, clientFactory, startMapIndex, endMapIndex);
+                  currentChunkIndex = 0;
+                  returnedChunks = 0;
+                  numChunks = client.openChunks();
+                }
+              }
+            } catch (IOException e1) {
+              logger.error(e1.getMessage(), e1);
+              exception.set(new IOException(e1.getMessage(), e1));
+            }
+          }
+        }
+    };
+    client = new ChunkClient(conf, shuffleKey, location, callback, clientFactory,
+            startMapIndex, endMapIndex);
+    numChunks = client.openChunks();
+  }
+
+  public synchronized boolean hasNext() {
+    return returnedChunks < numChunks;
+  }
+
+  public ByteBuf next() throws IOException {
+    checkException();
+    synchronized (this) {
+      if (currentChunkIndex < numChunks) {
+        fetchChunks();
+      }
+    }
+    ByteBuf chunk = null;
+    try {
+      while (chunk == null) {
+        checkException();
+        ChunkData chunkData = results.poll(500, TimeUnit.MILLISECONDS);
+        if (chunkData != null) {
+          synchronized (this) {
+            if (failedLocations.contains(chunkData.location)) {
+              chunkData.release();
+            } else {
+              chunk = chunkData.buf;
+              returnedChunks++;
+            }
+          }
+        }
+      }
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+      IOException ioe = new IOException(e);
+      exception.set(ioe);
+      throw ioe;
+    }
+    return chunk;
+  }
+
+  public void close() {
+    closed.set(true);
+    if (results.size() > 0) {
+      results.forEach(ChunkData::release);
+    }
+    results.clear();
+  }
+
+  private void fetchChunks() {
+    final int inFlight = currentChunkIndex - returnedChunks;
+    if (inFlight < fetchMaxReqsInFlight) {
+      final int toFetch =
+          Math.min(fetchMaxReqsInFlight - inFlight + 1, numChunks - currentChunkIndex);
+      for (int i = 0; i < toFetch; i++) {
+        client.fetchChunk(currentChunkIndex++);
+      }
+    }
+  }
+
+  private void checkException() throws IOException {
+    IOException e = exception.get();
+    if (e != null) {
+      throw e;
+    }
+  }
+
+  private static class ChunkData {
+    ByteBuf buf;
+    PartitionLocation location;
+
+    ChunkData(ByteBuf buf, PartitionLocation location) {
+      this.buf = buf;
+      this.location = location;
+    }
+
+    public void release() {
+      buf.release();
+    }
+  }
+}
diff --git a/client/src/test/java/com/aliyun/emr/rss/client/read/RetryingChunkClientSuiteJ.java b/client/src/test/java/com/aliyun/emr/rss/client/read/RetryingChunkClientSuiteJ.java
deleted file mode 100644
index 1a41e6e8..00000000
--- a/client/src/test/java/com/aliyun/emr/rss/client/read/RetryingChunkClientSuiteJ.java
+++ /dev/null
@@ -1,394 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.aliyun.emr.rss.client.read;
-
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.Arrays;
-import java.util.LinkedHashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Random;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ScheduledExecutorService;
-import java.util.concurrent.Semaphore;
-import java.util.concurrent.TimeUnit;
-
-import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.Sets;
-import io.netty.channel.Channel;
-import org.junit.Test;
-import org.mockito.stubbing.Answer;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
-import static org.mockito.Mockito.any;
-import static org.mockito.Mockito.anyInt;
-import static org.mockito.Mockito.anyObject;
-import static org.mockito.Mockito.anyString;
-import static org.mockito.Mockito.doAnswer;
-import static org.mockito.Mockito.eq;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.timeout;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.verifyNoMoreInteractions;
-
-import com.aliyun.emr.rss.common.RssConf;
-import com.aliyun.emr.rss.common.network.buffer.ManagedBuffer;
-import com.aliyun.emr.rss.common.network.buffer.NioManagedBuffer;
-import com.aliyun.emr.rss.common.network.client.ChunkReceivedCallback;
-import com.aliyun.emr.rss.common.network.client.TransportClient;
-import com.aliyun.emr.rss.common.network.client.TransportClientFactory;
-import com.aliyun.emr.rss.common.network.client.TransportResponseHandler;
-import com.aliyun.emr.rss.common.network.protocol.StreamHandle;
-import com.aliyun.emr.rss.common.protocol.PartitionLocation;
-import com.aliyun.emr.rss.common.util.ThreadUtils;
-
-public class RetryingChunkClientSuiteJ {
-
-  private static final int MASTER_RPC_PORT = 1234;
-  private static final int MASTER_PUSH_PORT = 1235;
-  private static final int MASTER_FETCH_PORT = 1236;
-  private static final int MASTER_REPLICATE_PORT = 1237;
-  private static final int SLAVE_RPC_PORT = 4321;
-  private static final int SLAVE_PUSH_PORT = 4322;
-  private static final int SLAVE_FETCH_PORT = 4323;
-  private static final int SLAVE_REPLICATE_PORT = 4324;
-  private static final PartitionLocation masterLocation = new PartitionLocation(
-    0, 1, "localhost", MASTER_RPC_PORT, MASTER_PUSH_PORT, MASTER_FETCH_PORT,
-    MASTER_REPLICATE_PORT, PartitionLocation.Mode.Master);
-  private static final PartitionLocation slaveLocation = new PartitionLocation(
-    0, 1, "localhost", SLAVE_RPC_PORT, SLAVE_PUSH_PORT, SLAVE_FETCH_PORT,
-    SLAVE_REPLICATE_PORT, PartitionLocation.Mode.Slave);
-
-  static {
-    masterLocation.setPeer(slaveLocation);
-    slaveLocation.setPeer(masterLocation);
-  }
-
-  ManagedBuffer chunk0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13]));
-  ManagedBuffer chunk1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
-  ManagedBuffer chunk2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19]));
-
-  @Test
-  public void testNoFailures() throws IOException, InterruptedException {
-    ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
-    Map<Integer, List<Object>> interactions = ImmutableMap.<Integer, List<Object>>builder()
-        .put(0, Arrays.asList(chunk0))
-        .put(1, Arrays.asList(chunk1))
-        .put(2, Arrays.asList(chunk2))
-        .build();
-
-    RetryingChunkClient client = performInteractions(interactions, callback);
-
-    verify(callback, timeout(5000)).onSuccess(eq(0), eq(chunk0));
-    verify(callback, timeout(5000)).onSuccess(eq(1), eq(chunk1));
-    verify(callback, timeout(5000)).onSuccess(eq(2), eq(chunk2));
-    verifyNoMoreInteractions(callback);
-
-    assertEquals(0, client.getNumTries());
-    assertEquals(masterLocation, client.getCurrentReplica().getLocation());
-  }
-
-  @Test
-  public void testUnrecoverableFailure() throws IOException, InterruptedException {
-    ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
-    Map<Integer, List<Object>> interactions = ImmutableMap.<Integer, List<Object>>builder()
-        .put(0, Arrays.asList(new RuntimeException("Ouch!")))
-        .put(1, Arrays.asList(chunk1))
-        .put(2, Arrays.asList(chunk2))
-        .build();
-    RetryingChunkClient client = performInteractions(interactions, callback);
-
-    verify(callback, timeout(5000)).onFailure(eq(0), any());
-    verify(callback, timeout(5000)).onSuccess(eq(1), eq(chunk1));
-    verify(callback, timeout(5000)).onSuccess(eq(2), eq(chunk2));
-    verifyNoMoreInteractions(callback);
-
-    assertEquals(0, client.getNumTries());
-    assertEquals(masterLocation, client.getCurrentReplica().getLocation());
-  }
-
-  @Test
-  public void testDuplicateSuccess() throws IOException, InterruptedException {
-    ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
-
-    Map<Integer, List<Object>> interactions = ImmutableMap.<Integer, List<Object>>builder()
-        .put(0, Arrays.asList(chunk0, chunk1))
-        .build();
-    RetryingChunkClient client = performInteractions(interactions, callback);
-    verify(callback, timeout(5000)).onSuccess(eq(0), eq(chunk0));
-    verifyNoMoreInteractions(callback);
-
-    assertEquals(0, client.getNumTries());
-    assertEquals(masterLocation, client.getCurrentReplica().getLocation());
-  }
-
-  @Test
-  public void testSingleIOException() throws IOException, InterruptedException {
-    Map<Integer, Object> result = new ConcurrentHashMap<>();
-    ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
-    Semaphore signal = new Semaphore(3);
-    signal.acquire(3);
-
-    Answer<Void> answer = invocation -> {
-      synchronized (signal) {
-        int chunkIndex = (Integer) invocation.getArguments()[0];
-        assertFalse(result.containsKey(chunkIndex));
-        Object value = invocation.getArguments()[1];
-        result.put(chunkIndex, value);
-        signal.release();
-      }
-      return null;
-    };
-    doAnswer(answer).when(callback).onSuccess(anyInt(), anyObject());
-    doAnswer(answer).when(callback).onFailure(anyInt(), anyObject());
-
-    Map<Integer, List<Object>> interactions = ImmutableMap.<Integer, List<Object>>builder()
-      .put(0, Arrays.asList(new IOException(), chunk0))
-      .put(1, Arrays.asList(chunk1))
-      .put(2, Arrays.asList(chunk2))
-      .build();
-    RetryingChunkClient client = performInteractions(interactions, callback);
-
-    while (!signal.tryAcquire(3, 500, TimeUnit.MILLISECONDS));
-    assertEquals(1, client.getNumTries());
-    assertEquals(slaveLocation, client.getCurrentReplica().getLocation());
-    assertEquals(chunk0, result.get(0));
-    assertEquals(chunk1, result.get(1));
-    assertEquals(chunk2, result.get(2));
-  }
-
-  @Test
-  public void testTwoIOExceptions() throws IOException, InterruptedException {
-    Map<Integer, Object> result = new ConcurrentHashMap<>();
-    ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
-    Semaphore signal = new Semaphore(3);
-    signal.acquire(3);
-
-    Answer<Void> answer = invocation -> {
-      synchronized (signal) {
-        int chunkIndex = (Integer) invocation.getArguments()[0];
-        assertFalse(result.containsKey(chunkIndex));
-        Object value = invocation.getArguments()[1];
-        result.put(chunkIndex, value);
-        signal.release();
-      }
-      return null;
-    };
-    doAnswer(answer).when(callback).onSuccess(anyInt(), anyObject());
-    doAnswer(answer).when(callback).onFailure(anyInt(), anyObject());
-
-    Map<Integer, List<Object>> interactions = ImmutableMap.<Integer, List<Object>>builder()
-      .put(0, Arrays.asList(new IOException("first ioexception"), chunk0))
-      .put(1, Arrays.asList(new IOException("second ioexception"), chunk1))
-      .put(2, Arrays.asList(chunk2))
-      .build();
-    RetryingChunkClient client = performInteractions(interactions, callback);
-
-    while (!signal.tryAcquire(3, 500, TimeUnit.MILLISECONDS));
-    assertEquals(1, client.getNumTries());
-    assertEquals(slaveLocation, client.getCurrentReplica().getLocation());
-    assertEquals(chunk0, result.get(0));
-    assertEquals(chunk1, result.get(1));
-    assertEquals(chunk2, result.get(2));
-  }
-
-  @Test
-  public void testThreeIOExceptions() throws IOException, InterruptedException {
-    Map<Integer, Object> result = new ConcurrentHashMap<>();
-    ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
-    Semaphore signal = new Semaphore(3);
-    signal.acquire(3);
-
-    Answer<Void> answer = invocation -> {
-      synchronized (signal) {
-        int chunkIndex = (Integer) invocation.getArguments()[0];
-        assertFalse(result.containsKey(chunkIndex));
-        Object value = invocation.getArguments()[1];
-        result.put(chunkIndex, value);
-        signal.release();
-      }
-      return null;
-    };
-    doAnswer(answer).when(callback).onSuccess(anyInt(), anyObject());
-    doAnswer(answer).when(callback).onFailure(anyInt(), anyObject());
-
-    Map<Integer, List<Object>> interactions = ImmutableMap.<Integer, List<Object>>builder()
-      .put(0, Arrays.asList(new IOException("first ioexception"), chunk0))
-      .put(1, Arrays.asList(new IOException("second ioexception"), chunk1))
-      .put(2, Arrays.asList(new IOException("third ioexception"), chunk2))
-      .build();
-
-    RetryingChunkClient client = performInteractions(interactions, callback);
-    while (!signal.tryAcquire(3, 500, TimeUnit.MILLISECONDS));
-    assertEquals(1, client.getNumTries());
-    assertEquals(slaveLocation, client.getCurrentReplica().getLocation());
-    assertEquals(chunk0, result.get(0));
-    assertEquals(chunk1, result.get(1));
-    assertEquals(chunk2, result.get(2));
-  }
-
-  @Test
-  public void testFailedWithIOExceptions() throws IOException, InterruptedException {
-    Map<Integer, Object> result = new ConcurrentHashMap<>();
-    ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
-    Semaphore signal = new Semaphore(3);
-    signal.acquire(3);
-
-    Answer<Void> answer = invocation -> {
-      synchronized (signal) {
-        int chunkIndex = (Integer) invocation.getArguments()[0];
-        assertFalse(result.containsKey(chunkIndex));
-        Object value = invocation.getArguments()[1];
-        result.put(chunkIndex, value);
-        signal.release();
-      }
-      return null;
-    };
-    doAnswer(answer).when(callback).onSuccess(anyInt(), anyObject());
-    doAnswer(answer).when(callback).onFailure(anyInt(), anyObject());
-
-    IOException ioe = new IOException("failed exception");
-    Map<Integer, List<Object>> interactions = ImmutableMap.<Integer, List<Object>>builder()
-      .put(0, Arrays.asList(ioe, ioe, ioe, ioe, ioe, chunk0))
-      .put(1, Arrays.asList(ioe, ioe, ioe, ioe, ioe, chunk1))
-      .put(2, Arrays.asList(ioe, ioe, ioe, ioe, ioe, chunk2))
-      .build();
-
-    RetryingChunkClient client = performInteractions(interactions, callback);
-    while (!signal.tryAcquire(3, 500, TimeUnit.MILLISECONDS));
-    // Note: this may exceeds the max retries we want, but it doesn't master.
-    assertEquals(4, client.getNumTries());
-    assertEquals(masterLocation, client.getCurrentReplica().getLocation());
-    assertEquals(ioe, result.get(0));
-    assertEquals(ioe, result.get(1));
-    assertEquals(ioe, result.get(2));
-  }
-
-  @Test
-  public void testRetryAndUnrecoverable() throws IOException, InterruptedException {
-    Map<Integer, Object> result = new ConcurrentHashMap<>();
-    ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
-    Semaphore signal = new Semaphore(3);
-    signal.acquire(3);
-
-    Answer<Void> answer = invocation -> {
-      synchronized (signal) {
-        int chunkIndex = (Integer) invocation.getArguments()[0];
-        assertFalse(result.containsKey(chunkIndex));
-        Object value = invocation.getArguments()[1];
-        result.put(chunkIndex, value);
-        signal.release();
-      }
-      return null;
-    };
-    doAnswer(answer).when(callback).onSuccess(anyInt(), anyObject());
-    doAnswer(answer).when(callback).onFailure(anyInt(), anyObject());
-
-    Exception re = new RuntimeException("failed exception");
-    Map<Integer, List<Object>> interactions = ImmutableMap.<Integer, List<Object>>builder()
-      .put(0, Arrays.asList(new IOException("first ioexception"), re, chunk0))
-      .put(1, Arrays.asList(chunk1))
-      .put(2, Arrays.asList(new IOException("second ioexception"), chunk2))
-      .build();
-
-    performInteractions(interactions, callback);
-    while (!signal.tryAcquire(3, 500, TimeUnit.MILLISECONDS));
-    assertEquals(re, result.get(0));
-    assertEquals(chunk1, result.get(1));
-    assertEquals(chunk2, result.get(2));
-  }
-
-  private static RetryingChunkClient performInteractions(
-      Map<Integer, List<Object>> interactions,
-      ChunkReceivedCallback callback) throws IOException, InterruptedException {
-    RssConf conf = new RssConf();
-    conf.set("rss.data.io.maxRetries", "1");
-    conf.set("rss.data.io.retryWait", "0");
-
-    // Contains all chunk ids that are referenced across all interactions.
-    LinkedHashSet<Integer> chunkIds = Sets.newLinkedHashSet(interactions.keySet());
-
-    final TransportClient client = new DummyTransportClient(chunkIds.size(), interactions);
-    final TransportClientFactory clientFactory = mock(TransportClientFactory.class);
-    doAnswer(invocation -> client).when(clientFactory).createClient(anyString(), anyInt());
-
-    RetryingChunkClient retryingChunkClient = new RetryingChunkClient(
-        conf, "test", masterLocation, callback, clientFactory);
-    chunkIds.stream().sorted().forEach(retryingChunkClient::fetchChunk);
-    return retryingChunkClient;
-  }
-
-  private static class DummyTransportClient extends TransportClient {
-
-    private static final Channel channel = mock(Channel.class);
-    private static final TransportResponseHandler handler = mock(TransportResponseHandler.class);
-
-    private final long streamId = new Random().nextInt(Integer.MAX_VALUE) * 1000L;
-    private final int numChunks;
-    private final Map<Integer, List<Object>> interactions;
-    private final Map<Integer, Integer> chunkIdToInterActionIndex;
-
-    private final ScheduledExecutorService schedule =
-      ThreadUtils.newDaemonThreadPoolScheduledExecutor("test-fetch-chunk", 3);
-
-    DummyTransportClient(int numChunks, Map<Integer, List<Object>> interactions) {
-      super(channel, handler);
-      this.numChunks = numChunks;
-      this.interactions = interactions;
-      this.chunkIdToInterActionIndex = new ConcurrentHashMap<>();
-      interactions.keySet().forEach((chunkId) -> chunkIdToInterActionIndex.putIfAbsent(chunkId, 0));
-    }
-
-    @Override
-    public void fetchChunk(long streamId, int chunkId, ChunkReceivedCallback callback) {
-      schedule.schedule(() -> {
-        Object action;
-        List<Object> interaction = interactions.get(chunkId);
-        synchronized (chunkIdToInterActionIndex) {
-          int index = chunkIdToInterActionIndex.get(chunkId);
-          assertTrue(index < interaction.size());
-          action = interaction.get(index);
-          chunkIdToInterActionIndex.put(chunkId, index + 1);
-        }
-
-        if (action instanceof ManagedBuffer) {
-          callback.onSuccess(chunkId, (ManagedBuffer) action);
-        } else if (action instanceof Exception) {
-          callback.onFailure(chunkId, (Exception) action);
-        } else {
-          fail("Can only handle ManagedBuffers and Exceptions, got " + action);
-        }
-      }, 500, TimeUnit.MILLISECONDS);
-    }
-
-    @Override
-    public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) {
-      StreamHandle handle = new StreamHandle(streamId, numChunks);
-      return handle.toByteBuffer();
-    }
-
-    @Override
-    public void close() {
-      super.close();
-      schedule.shutdownNow();
-    }
-  }
-}
diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/client/ChunkReceivedCallback.java b/common/src/main/java/com/aliyun/emr/rss/common/network/client/ChunkReceivedCallback.java
index f87ffa7d..88cb30b2 100644
--- a/common/src/main/java/com/aliyun/emr/rss/common/network/client/ChunkReceivedCallback.java
+++ b/common/src/main/java/com/aliyun/emr/rss/common/network/client/ChunkReceivedCallback.java
@@ -18,6 +18,7 @@
 package com.aliyun.emr.rss.common.network.client;
 
 import com.aliyun.emr.rss.common.network.buffer.ManagedBuffer;
+import com.aliyun.emr.rss.common.protocol.PartitionLocation;
 
 /**
  * Callback for the result of a single chunk result. For a single stream, the callbacks are
@@ -34,7 +35,7 @@ public interface ChunkReceivedCallback {
    * call returns. You must therefore either retain() the buffer or copy its contents before
    * returning.
    */
-  void onSuccess(int chunkIndex, ManagedBuffer buffer);
+  void onSuccess(int chunkIndex, ManagedBuffer buffer, PartitionLocation location);
 
   /**
    * Called upon failure to fetch a particular chunk. Note that this may actually be called due
@@ -43,5 +44,5 @@ public interface ChunkReceivedCallback {
    * After receiving a failure, the stream may or may not be valid. The client should not assume
    * that the server's side of the stream has been closed.
    */
-  void onFailure(int chunkIndex, Throwable e);
+  void onFailure(int chunkIndex, PartitionLocation location, Throwable e);
 }
diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/client/TransportClient.java b/common/src/main/java/com/aliyun/emr/rss/common/network/client/TransportClient.java
index 6c97c1d3..00138369 100644
--- a/common/src/main/java/com/aliyun/emr/rss/common/network/client/TransportClient.java
+++ b/common/src/main/java/com/aliyun/emr/rss/common/network/client/TransportClient.java
@@ -132,7 +132,7 @@ public class TransportClient implements Closeable {
       @Override
       protected void handleFailure(String errorMsg, Throwable cause) {
         handler.removeFetchRequest(streamChunkSlice);
-        callback.onFailure(chunkIndex, new IOException(errorMsg, cause));
+        callback.onFailure(chunkIndex, null, new IOException(errorMsg, cause));
       }
     };
     handler.addFetchRequest(streamChunkSlice, callback);
diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/client/TransportResponseHandler.java b/common/src/main/java/com/aliyun/emr/rss/common/network/client/TransportResponseHandler.java
index 32e3349c..533208cc 100644
--- a/common/src/main/java/com/aliyun/emr/rss/common/network/client/TransportResponseHandler.java
+++ b/common/src/main/java/com/aliyun/emr/rss/common/network/client/TransportResponseHandler.java
@@ -86,7 +86,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
   private void failOutstandingRequests(Throwable cause) {
     for (Map.Entry<StreamChunkSlice, ChunkReceivedCallback> entry : outstandingFetches.entrySet()) {
       try {
-        entry.getValue().onFailure(entry.getKey().chunkIndex, cause);
+        entry.getValue().onFailure(entry.getKey().chunkIndex, null, cause);
       } catch (Exception e) {
         logger.warn("ChunkReceivedCallback.onFailure throws exception", e);
       }
@@ -139,7 +139,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
         resp.body().release();
       } else {
         outstandingFetches.remove(resp.streamChunkSlice);
-        listener.onSuccess(resp.streamChunkSlice.chunkIndex, resp.body());
+        listener.onSuccess(resp.streamChunkSlice.chunkIndex, resp.body(), null);
         resp.body().release();
       }
     } else if (message instanceof ChunkFetchFailure) {
@@ -151,8 +151,9 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
       } else {
         outstandingFetches.remove(resp.streamChunkSlice);
         logger.warn("Receive ChunkFetchFailure, errorMsg {}", resp.errorString);
-        listener.onFailure(resp.streamChunkSlice.chunkIndex, new ChunkFetchFailureException(
-          "Failure while fetching " + resp.streamChunkSlice + ": " + resp.errorString));
+        listener.onFailure(resp.streamChunkSlice.chunkIndex, null,
+                new ChunkFetchFailureException("Failure while fetching " +
+                        resp.streamChunkSlice + ": " + resp.errorString));
       }
     } else if (message instanceof RpcResponse) {
       RpcResponse resp = (RpcResponse) message;
diff --git a/common/src/main/scala/com/aliyun/emr/rss/common/RssConf.scala b/common/src/main/scala/com/aliyun/emr/rss/common/RssConf.scala
index 91209338..732d305f 100644
--- a/common/src/main/scala/com/aliyun/emr/rss/common/RssConf.scala
+++ b/common/src/main/scala/com/aliyun/emr/rss/common/RssConf.scala
@@ -835,6 +835,10 @@ object RssConf extends Logging {
     conf.getTimeAsMs("rss.rpc.cache.expire", "15s")
   }
 
+  /**
+   * Reads the integer setting "rss.test.client.fetchFailedChuckIndex", defaulting to 2.
+   *
+   * NOTE(review): presumably the chunk index at which a test-injected fetch failure is
+   * triggered -- confirm against the code that consumes it. The key is spelled "Chuck" rather
+   * than "Chunk"; it is left untouched because existing configurations may set this exact key.
+   */
+  def testFetchFailedChunkIndex(conf: RssConf): Int =
+    conf.getInt("rss.test.client.fetchFailedChuckIndex", 2)
+
   val WorkingDirName = "hadoop/rss-worker/shuffle_data"
 
   // If we want to use multi-raft group we can
diff --git a/common/src/test/java/com/aliyun/emr/rss/common/network/ChunkFetchIntegrationSuiteJ.java b/common/src/test/java/com/aliyun/emr/rss/common/network/ChunkFetchIntegrationSuiteJ.java
index 24994a3a..55a05dbc 100644
--- a/common/src/test/java/com/aliyun/emr/rss/common/network/ChunkFetchIntegrationSuiteJ.java
+++ b/common/src/test/java/com/aliyun/emr/rss/common/network/ChunkFetchIntegrationSuiteJ.java
@@ -46,6 +46,7 @@ import com.aliyun.emr.rss.common.network.server.StreamManager;
 import com.aliyun.emr.rss.common.network.server.TransportServer;
 import com.aliyun.emr.rss.common.network.util.MapConfigProvider;
 import com.aliyun.emr.rss.common.network.util.TransportConf;
+import com.aliyun.emr.rss.common.protocol.PartitionLocation;
 
 public class ChunkFetchIntegrationSuiteJ {
   static final long STREAM_ID = 1;
@@ -151,7 +152,7 @@ public class ChunkFetchIntegrationSuiteJ {
 
     ChunkReceivedCallback callback = new ChunkReceivedCallback() {
       @Override
-      public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+      public void onSuccess(int chunkIndex, ManagedBuffer buffer, PartitionLocation location) {
         buffer.retain();
         res.successChunks.add(chunkIndex);
         res.buffers.add(buffer);
@@ -159,7 +160,7 @@ public class ChunkFetchIntegrationSuiteJ {
       }
 
       @Override
-      public void onFailure(int chunkIndex, Throwable e) {
+      public void onFailure(int chunkIndex, PartitionLocation location, Throwable e) {
         res.failedChunks.add(chunkIndex);
         sem.release();
       }
diff --git a/common/src/test/java/com/aliyun/emr/rss/common/network/RequestTimeoutIntegrationSuiteJ.java b/common/src/test/java/com/aliyun/emr/rss/common/network/RequestTimeoutIntegrationSuiteJ.java
index 04e55e0e..df62af47 100644
--- a/common/src/test/java/com/aliyun/emr/rss/common/network/RequestTimeoutIntegrationSuiteJ.java
+++ b/common/src/test/java/com/aliyun/emr/rss/common/network/RequestTimeoutIntegrationSuiteJ.java
@@ -43,6 +43,7 @@ import com.aliyun.emr.rss.common.network.server.StreamManager;
 import com.aliyun.emr.rss.common.network.server.TransportServer;
 import com.aliyun.emr.rss.common.network.util.MapConfigProvider;
 import com.aliyun.emr.rss.common.network.util.TransportConf;
+import com.aliyun.emr.rss.common.protocol.PartitionLocation;
 
 /**
  * Suite which ensures that requests that go without a response for the network timeout period are
@@ -263,7 +264,7 @@ public class RequestTimeoutIntegrationSuiteJ {
     }
 
     @Override
-    public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+    public void onSuccess(int chunkIndex, ManagedBuffer buffer, PartitionLocation location) {
       try {
         successLength = buffer.nioByteBuffer().remaining();
       } catch (IOException e) {
@@ -274,7 +275,7 @@ public class RequestTimeoutIntegrationSuiteJ {
     }
 
     @Override
-    public void onFailure(int chunkIndex, Throwable e) {
+    public void onFailure(int chunkIndex, PartitionLocation location, Throwable e) {
       failure = e;
       latch.countDown();
     }
diff --git a/common/src/test/java/com/aliyun/emr/rss/common/network/TransportResponseHandlerSuiteJ.java b/common/src/test/java/com/aliyun/emr/rss/common/network/TransportResponseHandlerSuiteJ.java
index e8c696f7..b17a3194 100644
--- a/common/src/test/java/com/aliyun/emr/rss/common/network/TransportResponseHandlerSuiteJ.java
+++ b/common/src/test/java/com/aliyun/emr/rss/common/network/TransportResponseHandlerSuiteJ.java
@@ -41,7 +41,7 @@ public class TransportResponseHandlerSuiteJ {
     assertEquals(1, handler.numOutstandingRequests());
 
     handler.handle(new ChunkFetchSuccess(streamChunkSlice, new TestManagedBuffer(123)));
-    verify(callback, times(1)).onSuccess(eq(0), any());
+    verify(callback, times(1)).onSuccess(eq(0), any(), any());
     assertEquals(0, handler.numOutstandingRequests());
   }
 
@@ -54,7 +54,7 @@ public class TransportResponseHandlerSuiteJ {
     assertEquals(1, handler.numOutstandingRequests());
 
     handler.handle(new ChunkFetchFailure(streamChunkSlice, "some error msg"));
-    verify(callback, times(1)).onFailure(eq(0), any());
+    verify(callback, times(1)).onFailure(eq(0), any(), any());
     assertEquals(0, handler.numOutstandingRequests());
   }
 
@@ -71,9 +71,9 @@ public class TransportResponseHandlerSuiteJ {
     handler.exceptionCaught(new Exception("duh duh duhhhh"));
 
     // should fail both b2 and b3
-    verify(callback, times(1)).onSuccess(eq(0), any());
-    verify(callback, times(1)).onFailure(eq(1), any());
-    verify(callback, times(1)).onFailure(eq(2), any());
+    verify(callback, times(1)).onSuccess(eq(0), any() ,any());
+    verify(callback, times(1)).onFailure(eq(1), any() ,any());
+    verify(callback, times(1)).onFailure(eq(2), any() ,any());
     assertEquals(0, handler.numOutstandingRequests());
   }
 
diff --git a/pom.xml b/pom.xml
index f43630c0..b50cb25b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -42,7 +42,7 @@
     <codahale.metrics.version>3.2.6</codahale.metrics.version>
     <javaxservlet.version>3.1.0</javaxservlet.version>
     <!-- Apache Ratis version -->
-    <ratis.version>2.2.0</ratis.version>
+    <ratis.version>2.4.0</ratis.version>
     <!-- ProtocolBuffer version, used to verify the protoc version and -->
     <!-- define the protobuf JAR version                               -->
     <protobuf.version>3.5.1</protobuf.version><!-- Maven protoc compiler -->
diff --git a/server-master/src/test/java/com/aliyun/emr/rss/service/deploy/master/clustermeta/ha/MasterStateMachineSuiteJ.java b/server-master/src/test/java/com/aliyun/emr/rss/service/deploy/master/clustermeta/ha/MasterStateMachineSuiteJ.java
index fa10ac4d..1ea4d6ff 100644
--- a/server-master/src/test/java/com/aliyun/emr/rss/service/deploy/master/clustermeta/ha/MasterStateMachineSuiteJ.java
+++ b/server-master/src/test/java/com/aliyun/emr/rss/service/deploy/master/clustermeta/ha/MasterStateMachineSuiteJ.java
@@ -27,7 +27,7 @@ import java.util.concurrent.ThreadLocalRandom;
 import java.util.regex.Matcher;
 
 import org.apache.ratis.server.storage.RaftStorage;
-import org.apache.ratis.server.storage.RaftStorageImpl;
+import org.apache.ratis.server.storage.StorageImplUtils;
 import org.apache.ratis.statemachine.SnapshotRetentionPolicy;
 import org.apache.ratis.statemachine.impl.SimpleStateMachineStorage;
 import org.junit.Assert;
@@ -76,7 +76,8 @@ public class MasterStateMachineSuiteJ extends RatisBaseSuiteJ {
 
     File storageDir = Utils.createTempDir("./", "snapshot");
 
-    final RaftStorage storage = new RaftStorageImpl(storageDir, null, 100);
+    final RaftStorage storage = StorageImplUtils.newRaftStorage(storageDir, null,
+      RaftStorage.StartupOption.FORMAT, 100);
     SimpleStateMachineStorage simpleStateMachineStorage =
       (SimpleStateMachineStorage)stateMachine.getStateMachineStorage();
     simpleStateMachineStorage.init(storage);
diff --git a/server-worker/src/test/java/com/aliyun/emr/rss/service/deploy/worker/FileWriterSuiteJ.java b/server-worker/src/test/java/com/aliyun/emr/rss/service/deploy/worker/FileWriterSuiteJ.java
index 6f7c49d1..f5653dcd 100644
--- a/server-worker/src/test/java/com/aliyun/emr/rss/service/deploy/worker/FileWriterSuiteJ.java
+++ b/server-worker/src/test/java/com/aliyun/emr/rss/service/deploy/worker/FileWriterSuiteJ.java
@@ -56,6 +56,7 @@ import com.aliyun.emr.rss.common.network.server.TransportServer;
 import com.aliyun.emr.rss.common.network.util.JavaUtils;
 import com.aliyun.emr.rss.common.network.util.MapConfigProvider;
 import com.aliyun.emr.rss.common.network.util.TransportConf;
+import com.aliyun.emr.rss.common.protocol.PartitionLocation;
 import com.aliyun.emr.rss.common.protocol.PartitionSplitMode;
 import com.aliyun.emr.rss.common.util.ThreadUtils;
 import com.aliyun.emr.rss.common.util.Utils;
@@ -181,7 +182,7 @@ public class FileWriterSuiteJ {
 
     ChunkReceivedCallback callback = new ChunkReceivedCallback() {
       @Override
-      public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+      public void onSuccess(int chunkIndex, ManagedBuffer buffer, PartitionLocation location) {
         buffer.retain();
         res.successChunks.add(chunkIndex);
         res.buffers.add(buffer);
@@ -189,7 +190,7 @@ public class FileWriterSuiteJ {
       }
 
       @Override
-      public void onFailure(int chunkIndex, Throwable e) {
+      public void onFailure(int chunkIndex, PartitionLocation location, Throwable e) {
         res.failedChunks.add(chunkIndex);
         sem.release();
       }