You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@celeborn.apache.org by zh...@apache.org on 2022/11/26 10:06:10 UTC

[incubator-celeborn] branch main updated: [CELEBORN-68] Client might fetch incorrect data chunk (#1010)

This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 9214b821 [CELEBORN-68] Client might fetch incorrect data chunk (#1010)
9214b821 is described below

commit 9214b821818ac6cecde2fd4f06bc03ad0c2a801f
Author: Keyong Zhou <zh...@apache.org>
AuthorDate: Sat Nov 26 18:06:06 2022 +0800

    [CELEBORN-68] Client might fetch incorrect data chunk (#1010)
---
 .../celeborn/client/read/DfsPartitionReader.java   |   8 +
 .../celeborn/client/read/PartitionReader.java      |   4 +
 .../org/apache/celeborn/client/read/Replica.java   |  93 -----
 .../celeborn/client/read/RetryingChunkClient.java  | 263 -------------
 .../celeborn/client/read/RssInputStream.java       |  80 +++-
 .../client/read/WorkerPartitionReader.java         |  61 ++-
 .../client/read/RetryingChunkClientSuiteJ.java     | 423 ---------------------
 .../common/network/client/TransportClient.java     |   8 +-
 .../org/apache/celeborn/common/CelebornConf.scala  |  18 +
 docs/configuration/client.md                       |   2 +
 .../deploy/worker/storage/FileWriterSuiteJ.java    |   2 +-
 11 files changed, 157 insertions(+), 805 deletions(-)

diff --git a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java
index b5e8ee18..b6ebc884 100644
--- a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java
+++ b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java
@@ -44,6 +44,7 @@ import org.apache.celeborn.common.util.ShuffleBlockInfoUtils;
 import org.apache.celeborn.common.util.Utils;
 
 public class DfsPartitionReader implements PartitionReader {
+  PartitionLocation location;
   private final int shuffleChunkSize;
   private final int fetchMaxReqsInFlight;
   private final LinkedBlockingQueue<ByteBuf> results;
@@ -66,6 +67,8 @@ public class DfsPartitionReader implements PartitionReader {
     fetchMaxReqsInFlight = conf.fetchMaxReqsInFlight();
     results = new LinkedBlockingQueue<>();
 
+    this.location = location;
+
     final List<Long> chunkOffsets = new ArrayList<>();
     if (endMapIndex != Integer.MAX_VALUE) {
       long fetchTimeoutMs = conf.fetchTimeoutMs();
@@ -194,4 +197,9 @@ public class DfsPartitionReader implements PartitionReader {
     }
     results.clear();
   }
+
+  @Override
+  public PartitionLocation getLocation() {
+    return location;
+  }
 }
diff --git a/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java
index 44aa63d7..fa88e8c2 100644
--- a/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java
+++ b/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java
@@ -21,10 +21,14 @@ import java.io.IOException;
 
 import io.netty.buffer.ByteBuf;
 
+import org.apache.celeborn.common.protocol.PartitionLocation;
+
 public interface PartitionReader {
   boolean hasNext();
 
   ByteBuf next() throws IOException;
 
   void close();
+
+  PartitionLocation getLocation();
 }
diff --git a/client/src/main/java/org/apache/celeborn/client/read/Replica.java b/client/src/main/java/org/apache/celeborn/client/read/Replica.java
deleted file mode 100644
index 1fe7d3b3..00000000
--- a/client/src/main/java/org/apache/celeborn/client/read/Replica.java
+++ /dev/null
@@ -1,93 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.celeborn.client.read;
-
-import java.io.IOException;
-import java.nio.ByteBuffer;
-
-import com.google.common.annotations.VisibleForTesting;
-
-import org.apache.celeborn.common.network.client.TransportClient;
-import org.apache.celeborn.common.network.client.TransportClientFactory;
-import org.apache.celeborn.common.network.protocol.Message;
-import org.apache.celeborn.common.network.protocol.OpenStream;
-import org.apache.celeborn.common.network.protocol.StreamHandle;
-import org.apache.celeborn.common.protocol.PartitionLocation;
-
-class Replica {
-  private final long timeoutMs;
-  private final String shuffleKey;
-  private final PartitionLocation location;
-  private final TransportClientFactory clientFactory;
-
-  private StreamHandle streamHandle;
-  private TransportClient client;
-  private final int startMapIndex;
-  private final int endMapIndex;
-
-  Replica(
-      long timeoutMs,
-      String shuffleKey,
-      PartitionLocation location,
-      TransportClientFactory clientFactory,
-      int startMapIndex,
-      int endMapIndex) {
-    this.timeoutMs = timeoutMs;
-    this.shuffleKey = shuffleKey;
-    this.location = location;
-    this.clientFactory = clientFactory;
-    this.startMapIndex = startMapIndex;
-    this.endMapIndex = endMapIndex;
-  }
-
-  public synchronized TransportClient getOrOpenStream() throws IOException, InterruptedException {
-    if (client == null || !client.isActive()) {
-      client = clientFactory.createClient(location.getHost(), location.getFetchPort());
-
-      OpenStream openBlocks =
-          new OpenStream(shuffleKey, location.getFileName(), startMapIndex, endMapIndex);
-      ByteBuffer response = client.sendRpcSync(openBlocks.toByteBuffer(), timeoutMs);
-      streamHandle = (StreamHandle) Message.decode(response);
-    }
-    return client;
-  }
-
-  public long getStreamId() {
-    return streamHandle.streamId;
-  }
-
-  public int getNumChunks() {
-    return streamHandle.numChunks;
-  }
-
-  @Override
-  public String toString() {
-    String shufflePartition =
-        String.format("%s:%d %s", location.getHost(), location.getFetchPort(), shuffleKey);
-    if (startMapIndex == 0 && endMapIndex == Integer.MAX_VALUE) {
-      return shufflePartition;
-    } else {
-      return String.format("%s[%d,%d)", shufflePartition, startMapIndex, endMapIndex);
-    }
-  }
-
-  @VisibleForTesting
-  PartitionLocation getLocation() {
-    return location;
-  }
-}
diff --git a/client/src/main/java/org/apache/celeborn/client/read/RetryingChunkClient.java b/client/src/main/java/org/apache/celeborn/client/read/RetryingChunkClient.java
deleted file mode 100644
index c044b2b5..00000000
--- a/client/src/main/java/org/apache/celeborn/client/read/RetryingChunkClient.java
+++ /dev/null
@@ -1,263 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.celeborn.client.read;
-
-import java.io.IOException;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.TimeoutException;
-
-import com.google.common.annotations.VisibleForTesting;
-import com.google.common.util.concurrent.Uninterruptibles;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import org.apache.celeborn.common.CelebornConf;
-import org.apache.celeborn.common.network.buffer.ManagedBuffer;
-import org.apache.celeborn.common.network.client.ChunkReceivedCallback;
-import org.apache.celeborn.common.network.client.TransportClient;
-import org.apache.celeborn.common.network.client.TransportClientFactory;
-import org.apache.celeborn.common.network.util.NettyUtils;
-import org.apache.celeborn.common.network.util.TransportConf;
-import org.apache.celeborn.common.protocol.PartitionLocation;
-import org.apache.celeborn.common.protocol.TransportModuleConstants;
-import org.apache.celeborn.common.util.Utils;
-
-/**
- * Encapsulate the Partition Location information, so that for the file corresponding to this
- * Partition Location, you can ignore whether there is a retry and whether the Master/Slave switch
- * is performed.
- *
- * <p>Specifically, for a file, we can try maxTries times, and each attempt actually includes
- * attempts to all available copies in a Partition Location. In this way, we can simply take
- * advantage of the ability of multiple copies, and can also ensure that the number of retries for a
- * file will not be too many. Each retry is actually a switch between Master and Slave. Therefore,
- * each retry needs to create a new connection and reopen the file to generate the stream id.
- */
-public class RetryingChunkClient {
-  private static final Logger logger = LoggerFactory.getLogger(RetryingChunkClient.class);
-  private static final ExecutorService executorService =
-      Executors.newCachedThreadPool(NettyUtils.createThreadFactory("fetch-chuck"));
-
-  private final ChunkReceivedCallback callback;
-  private final Replica[] replicas;
-  private final long retryWaitMs;
-  private final int maxTries;
-
-  private volatile int numTries = 0;
-
-  public RetryingChunkClient(
-      CelebornConf conf,
-      String shuffleKey,
-      PartitionLocation location,
-      ChunkReceivedCallback callback,
-      TransportClientFactory clientFactory) {
-    this(conf, shuffleKey, location, callback, clientFactory, 0, Integer.MAX_VALUE);
-  }
-
-  public RetryingChunkClient(
-      CelebornConf conf,
-      String shuffleKey,
-      PartitionLocation location,
-      ChunkReceivedCallback callback,
-      TransportClientFactory clientFactory,
-      int startMapIndex,
-      int endMapIndex) {
-    TransportConf transportConf =
-        Utils.fromCelebornConf(conf, TransportModuleConstants.DATA_MODULE, 0);
-
-    this.callback = callback;
-    this.retryWaitMs = transportConf.ioRetryWaitTimeMs();
-
-    long fetchTimeoutMs = conf.fetchTimeoutMs();
-
-    if (location == null) {
-      throw new IllegalArgumentException("Must contain at least one available PartitionLocation.");
-    } else {
-      Replica main =
-          new Replica(
-              fetchTimeoutMs, shuffleKey, location, clientFactory, startMapIndex, endMapIndex);
-      PartitionLocation peerLoc = location.getPeer();
-      if (peerLoc == null) {
-        replicas = new Replica[] {main};
-      } else {
-        Replica peer =
-            new Replica(
-                fetchTimeoutMs, shuffleKey, peerLoc, clientFactory, startMapIndex, endMapIndex);
-        replicas = new Replica[] {main, peer};
-      }
-    }
-
-    this.maxTries = (transportConf.maxIORetries() + 1) * replicas.length;
-  }
-
-  /**
-   * This method should only be called once after RetryingChunkReader is initialized, so it is
-   * assumed that there is no concurrency problem when it is called.
-   *
-   * @return numChunks.
-   */
-  public synchronized int openChunks() throws IOException {
-    int numChunks = -1;
-    Replica currentReplica = null;
-    Exception currentException = null;
-    while (numChunks == -1 && hasRemainingRetries()) {
-      // Only not wait for first request to each replicate.
-      currentReplica = getCurrentReplica();
-      if (numTries >= replicas.length) {
-        logger.info(
-            "Retrying openChunk ({}/{}) for chunk from {} after {} ms.",
-            numTries,
-            maxTries,
-            currentReplica,
-            retryWaitMs);
-        Uninterruptibles.sleepUninterruptibly(retryWaitMs, TimeUnit.MILLISECONDS);
-      }
-      try {
-        currentReplica.getOrOpenStream();
-        numChunks = currentReplica.getNumChunks();
-      } catch (InterruptedException e) {
-        Thread.currentThread().interrupt();
-        throw new IOException(e);
-      } catch (Exception e) {
-        logger.error(
-            "Exception raised while sending open chunks message to " + currentReplica + ".", e);
-        currentException = e;
-        if (shouldRetry(e)) {
-          numTries += 1;
-        } else {
-          break;
-        }
-      }
-    }
-    if (numChunks == -1) {
-      if (currentException != null) {
-        throw new IOException(
-            String.format(
-                "Could not open chunks from %s after %d tries.", currentReplica, numTries),
-            currentException);
-      } else {
-        throw new IOException(
-            String.format(
-                "Could not open chunks from %s after %d tries.", currentReplica, numTries));
-      }
-    }
-    return numChunks;
-  }
-
-  /**
-   * Fetch for a chunk. It can be retried multiple times, so there is no guarantee that the order
-   * will arrive on the server side, nor can it guarantee an orderly return. Therefore, the chunks
-   * should be as orderly as possible when calling.
-   *
-   * @param chunkIndex the index of the chunk to be fetched.
-   */
-  public void fetchChunk(int chunkIndex) {
-    Replica replica;
-    RetryingChunkReceiveCallback callback;
-    synchronized (this) {
-      replica = getCurrentReplica();
-      callback = new RetryingChunkReceiveCallback(numTries);
-    }
-    try {
-      TransportClient client = replica.getOrOpenStream();
-      client.fetchChunk(replica.getStreamId(), chunkIndex, callback);
-    } catch (Exception e) {
-      logger.error(
-          "Exception raised while beginning fetch chunk {}{}.",
-          chunkIndex,
-          numTries > 0 ? " (after " + numTries + " retries)" : "",
-          e);
-
-      if (shouldRetry(e)) {
-        initiateRetry(chunkIndex, callback.currentNumTries);
-      } else {
-        callback.onFailure(chunkIndex, e);
-      }
-    }
-  }
-
-  @VisibleForTesting
-  Replica getCurrentReplica() {
-    int currentReplicaIndex = numTries % replicas.length;
-    return replicas[currentReplicaIndex];
-  }
-
-  @VisibleForTesting
-  int getNumTries() {
-    return numTries;
-  }
-
-  private boolean hasRemainingRetries() {
-    return numTries < maxTries;
-  }
-
-  private synchronized boolean shouldRetry(Throwable e) {
-    boolean isIOException =
-        e instanceof IOException
-            || e instanceof TimeoutException
-            || (e.getCause() != null && e.getCause() instanceof TimeoutException)
-            || (e.getCause() != null && e.getCause() instanceof IOException)
-            || (e instanceof RuntimeException
-                && e.getMessage().startsWith(IOException.class.getName()));
-    return isIOException && hasRemainingRetries();
-  }
-
-  @SuppressWarnings("UnstableApiUsage")
-  private synchronized void initiateRetry(final int chunkIndex, int currentNumTries) {
-    numTries = Math.max(numTries, currentNumTries + 1);
-
-    logger.info(
-        "Retrying fetch ({}/{}) for chunk {} from {} after {} ms.",
-        currentNumTries,
-        maxTries,
-        chunkIndex,
-        getCurrentReplica(),
-        retryWaitMs);
-
-    executorService.submit(
-        () -> {
-          Uninterruptibles.sleepUninterruptibly(retryWaitMs, TimeUnit.MILLISECONDS);
-          fetchChunk(chunkIndex);
-        });
-  }
-
-  private class RetryingChunkReceiveCallback implements ChunkReceivedCallback {
-    final int currentNumTries;
-
-    RetryingChunkReceiveCallback(int currentNumTries) {
-      this.currentNumTries = currentNumTries;
-    }
-
-    @Override
-    public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
-      callback.onSuccess(chunkIndex, buffer);
-    }
-
-    @Override
-    public void onFailure(int chunkIndex, Throwable e) {
-      if (shouldRetry(e)) {
-        initiateRetry(chunkIndex, this.currentNumTries);
-      } else {
-        logger.error("Abandon to fetch chunk {} after {} tries.", chunkIndex, this.currentNumTries);
-        callback.onFailure(chunkIndex, e);
-      }
-    }
-  }
-}
diff --git a/client/src/main/java/org/apache/celeborn/client/read/RssInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/RssInputStream.java
index 54418832..17081751 100644
--- a/client/src/main/java/org/apache/celeborn/client/read/RssInputStream.java
+++ b/client/src/main/java/org/apache/celeborn/client/read/RssInputStream.java
@@ -105,6 +105,8 @@ public abstract class RssInputStream extends InputStream {
 
     private ByteBuf currentChunk;
     private PartitionReader currentReader;
+    private final int fetchChunkMaxRetry;
+    private int fetchChunkRetryCnt = 0;
     private int fileIndex;
     private int position;
     private int limit;
@@ -144,6 +146,8 @@ public abstract class RssInputStream extends InputStream {
 
       decompressor = Decompressor.getDecompressor(conf);
 
+      fetchChunkMaxRetry = conf.fetchMaxRetries();
+
       moveToNextReader();
     }
 
@@ -180,6 +184,9 @@ public abstract class RssInputStream extends InputStream {
         }
         currentLocation = locations[fileIndex];
       }
+
+      fetchChunkRetryCnt = 0;
+
       return currentLocation;
     }
 
@@ -192,7 +199,7 @@ public abstract class RssInputStream extends InputStream {
       if (currentLocation == null) {
         return;
       }
-      currentReader = createReader(currentLocation);
+      currentReader = createReaderWithRetry(currentLocation);
       fileIndex++;
       while (!currentReader.hasNext()) {
         currentReader.close();
@@ -201,13 +208,65 @@ public abstract class RssInputStream extends InputStream {
         if (currentLocation == null) {
           return;
         }
-        currentReader = createReader(currentLocation);
+        currentReader = createReaderWithRetry(currentLocation);
         fileIndex++;
       }
-      currentChunk = currentReader.next();
+      currentChunk = getNextChunk();
+    }
+
+    private PartitionReader createReaderWithRetry(PartitionLocation location) throws IOException {
+      while (fetchChunkRetryCnt < fetchChunkMaxRetry) {
+        try {
+          return createReader(location, fetchChunkRetryCnt, fetchChunkMaxRetry);
+        } catch (Exception e) {
+          fetchChunkRetryCnt++;
+          if (location.getPeer() != null) {
+            location = location.getPeer();
+            logger.warn(
+                "CreatePartitionReader failed {}/{} times, change to peer",
+                fetchChunkRetryCnt,
+                fetchChunkMaxRetry);
+          } else {
+            logger.warn(
+                "CreatePartitionReader failed {}/{} times, retry the same location",
+                fetchChunkRetryCnt,
+                fetchChunkMaxRetry);
+          }
+        }
+      }
+      throw new IOException("createPartitionReader failed!");
     }
 
-    private PartitionReader createReader(PartitionLocation location) throws IOException {
+    private ByteBuf getNextChunk() throws IOException {
+      while (fetchChunkRetryCnt < fetchChunkMaxRetry) {
+        try {
+          return currentReader.next();
+        } catch (Exception e) {
+          fetchChunkRetryCnt++;
+          currentReader.close();
+          if (fetchChunkRetryCnt == fetchChunkMaxRetry) {
+            logger.warn("Fetch chunk fail exceeds max retry {}", fetchChunkRetryCnt);
+            throw new IOException("Fetch chunk failed for " + fetchChunkRetryCnt + " times");
+          } else {
+            if (currentReader.getLocation().getPeer() != null) {
+              logger.warn(
+                  "Fetch chunk failed {}/{} times, change to peer",
+                  fetchChunkRetryCnt,
+                  fetchChunkMaxRetry);
+              currentReader = createReaderWithRetry(currentReader.getLocation().getPeer());
+            } else {
+              logger.warn("Fetch chunk failed {}/{} times", fetchChunkRetryCnt, fetchChunkMaxRetry);
+              currentReader = createReaderWithRetry(currentReader.getLocation());
+            }
+          }
+        }
+      }
+      throw new IOException("Fetch chunk failed!");
+    }
+
+    private PartitionReader createReader(
+        PartitionLocation location, int fetchChunkRetryCnt, int fetchChunkMaxRetry)
+        throws IOException {
       if (location.getPeer() == null) {
         logger.debug("Partition {} has only one partition replica.", location);
       }
@@ -220,7 +279,14 @@ public abstract class RssInputStream extends InputStream {
       if (storageInfo.getType() == StorageInfo.Type.HDD
           || storageInfo.getType() == StorageInfo.Type.SSD) {
         return new WorkerPartitionReader(
-            conf, shuffleKey, location, clientFactory, startMapIndex, endMapIndex);
+            conf,
+            shuffleKey,
+            location,
+            clientFactory,
+            startMapIndex,
+            endMapIndex,
+            fetchChunkRetryCnt,
+            fetchChunkMaxRetry);
       }
       if (storageInfo.getType() == StorageInfo.Type.HDFS) {
         return new DfsPartitionReader(
@@ -287,7 +353,7 @@ public abstract class RssInputStream extends InputStream {
     @Override
     public void close() {
       int locationsCount = locations.length;
-      logger.warn(
+      logger.debug(
           "total location count {} read {} skip {}",
           locationsCount,
           locationsCount - skipCount.sum(),
@@ -310,7 +376,7 @@ public abstract class RssInputStream extends InputStream {
       }
       currentChunk = null;
       if (currentReader.hasNext()) {
-        currentChunk = currentReader.next();
+        currentChunk = getNextChunk();
         return true;
       } else if (fileIndex < locations.length) {
         moveToNextReader();
diff --git a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
index 01e35b4d..9d154f58 100644
--- a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
+++ b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
@@ -18,6 +18,7 @@
 package org.apache.celeborn.client.read;
 
 import java.io.IOException;
+import java.nio.ByteBuffer;
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
@@ -31,35 +32,48 @@ import org.apache.celeborn.common.CelebornConf;
 import org.apache.celeborn.common.network.buffer.ManagedBuffer;
 import org.apache.celeborn.common.network.buffer.NettyManagedBuffer;
 import org.apache.celeborn.common.network.client.ChunkReceivedCallback;
+import org.apache.celeborn.common.network.client.TransportClient;
 import org.apache.celeborn.common.network.client.TransportClientFactory;
+import org.apache.celeborn.common.network.protocol.Message;
+import org.apache.celeborn.common.network.protocol.OpenStream;
+import org.apache.celeborn.common.network.protocol.StreamHandle;
 import org.apache.celeborn.common.protocol.PartitionLocation;
 
 public class WorkerPartitionReader implements PartitionReader {
   private final Logger logger = LoggerFactory.getLogger(WorkerPartitionReader.class);
-  private final RetryingChunkClient client;
-  private final int numChunks;
+  private PartitionLocation location;
+  private final TransportClient client;
+  private StreamHandle streamHandle;
 
   private int returnedChunks;
   private int chunkIndex;
 
   private final LinkedBlockingQueue<ByteBuf> results;
+  private final ChunkReceivedCallback callback;
 
   private final AtomicReference<IOException> exception = new AtomicReference<>();
   private final int fetchMaxReqsInFlight;
   private boolean closed = false;
 
+  // for test
+  private int fetchChunkRetryCnt;
+  private int fetchChunkMaxRetry;
+  private final boolean testFetch;
+
   WorkerPartitionReader(
       CelebornConf conf,
       String shuffleKey,
       PartitionLocation location,
       TransportClientFactory clientFactory,
       int startMapIndex,
-      int endMapIndex)
+      int endMapIndex,
+      int fetchChunkRetryCnt,
+      int fetchChunkMaxRetry)
       throws IOException {
     fetchMaxReqsInFlight = conf.fetchMaxReqsInFlight();
     results = new LinkedBlockingQueue<>();
     // only add the buffer to results queue if this reader is not closed.
-    ChunkReceivedCallback callback =
+    callback =
         new ChunkReceivedCallback() {
           @Override
           public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
@@ -80,19 +94,31 @@ public class WorkerPartitionReader implements PartitionReader {
             exception.set(new IOException(errorMsg, e));
           }
         };
-    client =
-        new RetryingChunkClient(
-            conf, shuffleKey, location, callback, clientFactory, startMapIndex, endMapIndex);
-    numChunks = client.openChunks();
+    try {
+      client = clientFactory.createClient(location.getHost(), location.getFetchPort());
+    } catch (InterruptedException ie) {
+      throw new IOException("Interrupted when createClient", ie);
+    }
+    OpenStream openBlocks =
+        new OpenStream(shuffleKey, location.getFileName(), startMapIndex, endMapIndex);
+    long timeoutMs = conf.fetchTimeoutMs();
+    ByteBuffer response = client.sendRpcSync(openBlocks.toByteBuffer(), timeoutMs);
+    streamHandle = (StreamHandle) Message.decode(response);
+
+    this.location = location;
+
+    this.fetchChunkRetryCnt = fetchChunkRetryCnt;
+    this.fetchChunkMaxRetry = fetchChunkMaxRetry;
+    testFetch = conf.testFetchFailure();
   }
 
   public boolean hasNext() {
-    return returnedChunks < numChunks;
+    return returnedChunks < streamHandle.numChunks;
   }
 
   public ByteBuf next() throws IOException {
     checkException();
-    if (chunkIndex < numChunks) {
+    if (chunkIndex < streamHandle.numChunks) {
       fetchChunks();
     }
     ByteBuf chunk = null;
@@ -121,12 +147,23 @@ public class WorkerPartitionReader implements PartitionReader {
     results.clear();
   }
 
+  @Override
+  public PartitionLocation getLocation() {
+    return location;
+  }
+
   private void fetchChunks() {
     final int inFlight = chunkIndex - returnedChunks;
     if (inFlight < fetchMaxReqsInFlight) {
-      final int toFetch = Math.min(fetchMaxReqsInFlight - inFlight + 1, numChunks - chunkIndex);
+      final int toFetch =
+          Math.min(fetchMaxReqsInFlight - inFlight + 1, streamHandle.numChunks - chunkIndex);
       for (int i = 0; i < toFetch; i++) {
-        client.fetchChunk(chunkIndex++);
+        if (testFetch && fetchChunkRetryCnt < fetchChunkMaxRetry - 1 && chunkIndex == 3) {
+          callback.onFailure(chunkIndex, new IOException("Test fetch chunk failure"));
+        } else {
+          client.fetchChunk(streamHandle.streamId, chunkIndex, callback);
+          chunkIndex++;
+        }
       }
     }
   }
diff --git a/client/src/test/java/org/apache/celeborn/client/read/RetryingChunkClientSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/read/RetryingChunkClientSuiteJ.java
deleted file mode 100644
index 0b85349b..00000000
--- a/client/src/test/java/org/apache/celeborn/client/read/RetryingChunkClientSuiteJ.java
+++ /dev/null
@@ -1,423 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.celeborn.client.read;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
-import static org.mockito.Mockito.any;
-import static org.mockito.Mockito.anyInt;
-import static org.mockito.Mockito.anyObject;
-import static org.mockito.Mockito.anyString;
-import static org.mockito.Mockito.doAnswer;
-import static org.mockito.Mockito.eq;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.timeout;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.verifyNoMoreInteractions;
-
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.Arrays;
-import java.util.LinkedHashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Random;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ScheduledExecutorService;
-import java.util.concurrent.Semaphore;
-import java.util.concurrent.TimeUnit;
-
-import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.Sets;
-import io.netty.channel.Channel;
-import org.junit.Test;
-import org.mockito.stubbing.Answer;
-
-import org.apache.celeborn.common.CelebornConf;
-import org.apache.celeborn.common.network.buffer.ManagedBuffer;
-import org.apache.celeborn.common.network.buffer.NioManagedBuffer;
-import org.apache.celeborn.common.network.client.ChunkReceivedCallback;
-import org.apache.celeborn.common.network.client.TransportClient;
-import org.apache.celeborn.common.network.client.TransportClientFactory;
-import org.apache.celeborn.common.network.client.TransportResponseHandler;
-import org.apache.celeborn.common.network.protocol.StreamHandle;
-import org.apache.celeborn.common.protocol.PartitionLocation;
-import org.apache.celeborn.common.util.ThreadUtils;
-
-public class RetryingChunkClientSuiteJ {
-
-  private static final int MASTER_RPC_PORT = 1234;
-  private static final int MASTER_PUSH_PORT = 1235;
-  private static final int MASTER_FETCH_PORT = 1236;
-  private static final int MASTER_REPLICATE_PORT = 1237;
-  private static final int SLAVE_RPC_PORT = 4321;
-  private static final int SLAVE_PUSH_PORT = 4322;
-  private static final int SLAVE_FETCH_PORT = 4323;
-  private static final int SLAVE_REPLICATE_PORT = 4324;
-  private static final PartitionLocation masterLocation =
-      new PartitionLocation(
-          0,
-          1,
-          "localhost",
-          MASTER_RPC_PORT,
-          MASTER_PUSH_PORT,
-          MASTER_FETCH_PORT,
-          MASTER_REPLICATE_PORT,
-          PartitionLocation.Mode.MASTER);
-  private static final PartitionLocation slaveLocation =
-      new PartitionLocation(
-          0,
-          1,
-          "localhost",
-          SLAVE_RPC_PORT,
-          SLAVE_PUSH_PORT,
-          SLAVE_FETCH_PORT,
-          SLAVE_REPLICATE_PORT,
-          PartitionLocation.Mode.SLAVE);
-
-  static {
-    masterLocation.setPeer(slaveLocation);
-    slaveLocation.setPeer(masterLocation);
-  }
-
-  ManagedBuffer chunk0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13]));
-  ManagedBuffer chunk1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
-  ManagedBuffer chunk2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19]));
-
-  @Test
-  public void testNoFailures() throws IOException, InterruptedException {
-    ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
-    Map<Integer, List<Object>> interactions =
-        ImmutableMap.<Integer, List<Object>>builder()
-            .put(0, Arrays.asList(chunk0))
-            .put(1, Arrays.asList(chunk1))
-            .put(2, Arrays.asList(chunk2))
-            .build();
-
-    RetryingChunkClient client = performInteractions(interactions, callback);
-
-    verify(callback, timeout(5000)).onSuccess(eq(0), eq(chunk0));
-    verify(callback, timeout(5000)).onSuccess(eq(1), eq(chunk1));
-    verify(callback, timeout(5000)).onSuccess(eq(2), eq(chunk2));
-    verifyNoMoreInteractions(callback);
-
-    assertEquals(0, client.getNumTries());
-    assertEquals(masterLocation, client.getCurrentReplica().getLocation());
-  }
-
-  @Test
-  public void testUnrecoverableFailure() throws IOException, InterruptedException {
-    ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
-    Map<Integer, List<Object>> interactions =
-        ImmutableMap.<Integer, List<Object>>builder()
-            .put(0, Arrays.asList(new RuntimeException("Ouch!")))
-            .put(1, Arrays.asList(chunk1))
-            .put(2, Arrays.asList(chunk2))
-            .build();
-    RetryingChunkClient client = performInteractions(interactions, callback);
-
-    verify(callback, timeout(5000)).onFailure(eq(0), any());
-    verify(callback, timeout(5000)).onSuccess(eq(1), eq(chunk1));
-    verify(callback, timeout(5000)).onSuccess(eq(2), eq(chunk2));
-    verifyNoMoreInteractions(callback);
-
-    assertEquals(0, client.getNumTries());
-    assertEquals(masterLocation, client.getCurrentReplica().getLocation());
-  }
-
-  @Test
-  public void testDuplicateSuccess() throws IOException, InterruptedException {
-    ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
-
-    Map<Integer, List<Object>> interactions =
-        ImmutableMap.<Integer, List<Object>>builder().put(0, Arrays.asList(chunk0, chunk1)).build();
-    RetryingChunkClient client = performInteractions(interactions, callback);
-    verify(callback, timeout(5000)).onSuccess(eq(0), eq(chunk0));
-    verifyNoMoreInteractions(callback);
-
-    assertEquals(0, client.getNumTries());
-    assertEquals(masterLocation, client.getCurrentReplica().getLocation());
-  }
-
-  @Test
-  public void testSingleIOException() throws IOException, InterruptedException {
-    Map<Integer, Object> result = new ConcurrentHashMap<>();
-    ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
-    Semaphore signal = new Semaphore(3);
-    signal.acquire(3);
-
-    Answer<Void> answer =
-        invocation -> {
-          synchronized (signal) {
-            int chunkIndex = (Integer) invocation.getArguments()[0];
-            assertFalse(result.containsKey(chunkIndex));
-            Object value = invocation.getArguments()[1];
-            result.put(chunkIndex, value);
-            signal.release();
-          }
-          return null;
-        };
-    doAnswer(answer).when(callback).onSuccess(anyInt(), anyObject());
-    doAnswer(answer).when(callback).onFailure(anyInt(), anyObject());
-
-    Map<Integer, List<Object>> interactions =
-        ImmutableMap.<Integer, List<Object>>builder()
-            .put(0, Arrays.asList(new IOException(), chunk0))
-            .put(1, Arrays.asList(chunk1))
-            .put(2, Arrays.asList(chunk2))
-            .build();
-    RetryingChunkClient client = performInteractions(interactions, callback);
-
-    while (!signal.tryAcquire(3, 500, TimeUnit.MILLISECONDS)) ;
-    assertEquals(1, client.getNumTries());
-    assertEquals(slaveLocation, client.getCurrentReplica().getLocation());
-    assertEquals(chunk0, result.get(0));
-    assertEquals(chunk1, result.get(1));
-    assertEquals(chunk2, result.get(2));
-  }
-
-  @Test
-  public void testTwoIOExceptions() throws IOException, InterruptedException {
-    Map<Integer, Object> result = new ConcurrentHashMap<>();
-    ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
-    Semaphore signal = new Semaphore(3);
-    signal.acquire(3);
-
-    Answer<Void> answer =
-        invocation -> {
-          synchronized (signal) {
-            int chunkIndex = (Integer) invocation.getArguments()[0];
-            assertFalse(result.containsKey(chunkIndex));
-            Object value = invocation.getArguments()[1];
-            result.put(chunkIndex, value);
-            signal.release();
-          }
-          return null;
-        };
-    doAnswer(answer).when(callback).onSuccess(anyInt(), anyObject());
-    doAnswer(answer).when(callback).onFailure(anyInt(), anyObject());
-
-    Map<Integer, List<Object>> interactions =
-        ImmutableMap.<Integer, List<Object>>builder()
-            .put(0, Arrays.asList(new IOException("first ioexception"), chunk0))
-            .put(1, Arrays.asList(new IOException("second ioexception"), chunk1))
-            .put(2, Arrays.asList(chunk2))
-            .build();
-    RetryingChunkClient client = performInteractions(interactions, callback);
-
-    while (!signal.tryAcquire(3, 500, TimeUnit.MILLISECONDS)) ;
-    assertEquals(1, client.getNumTries());
-    assertEquals(slaveLocation, client.getCurrentReplica().getLocation());
-    assertEquals(chunk0, result.get(0));
-    assertEquals(chunk1, result.get(1));
-    assertEquals(chunk2, result.get(2));
-  }
-
-  @Test
-  public void testThreeIOExceptions() throws IOException, InterruptedException {
-    Map<Integer, Object> result = new ConcurrentHashMap<>();
-    ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
-    Semaphore signal = new Semaphore(3);
-    signal.acquire(3);
-
-    Answer<Void> answer =
-        invocation -> {
-          synchronized (signal) {
-            int chunkIndex = (Integer) invocation.getArguments()[0];
-            assertFalse(result.containsKey(chunkIndex));
-            Object value = invocation.getArguments()[1];
-            result.put(chunkIndex, value);
-            signal.release();
-          }
-          return null;
-        };
-    doAnswer(answer).when(callback).onSuccess(anyInt(), anyObject());
-    doAnswer(answer).when(callback).onFailure(anyInt(), anyObject());
-
-    Map<Integer, List<Object>> interactions =
-        ImmutableMap.<Integer, List<Object>>builder()
-            .put(0, Arrays.asList(new IOException("first ioexception"), chunk0))
-            .put(1, Arrays.asList(new IOException("second ioexception"), chunk1))
-            .put(2, Arrays.asList(new IOException("third ioexception"), chunk2))
-            .build();
-
-    RetryingChunkClient client = performInteractions(interactions, callback);
-    while (!signal.tryAcquire(3, 500, TimeUnit.MILLISECONDS)) ;
-    assertEquals(1, client.getNumTries());
-    assertEquals(slaveLocation, client.getCurrentReplica().getLocation());
-    assertEquals(chunk0, result.get(0));
-    assertEquals(chunk1, result.get(1));
-    assertEquals(chunk2, result.get(2));
-  }
-
-  @Test
-  public void testFailedWithIOExceptions() throws IOException, InterruptedException {
-    Map<Integer, Object> result = new ConcurrentHashMap<>();
-    ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
-    Semaphore signal = new Semaphore(3);
-    signal.acquire(3);
-
-    Answer<Void> answer =
-        invocation -> {
-          synchronized (signal) {
-            int chunkIndex = (Integer) invocation.getArguments()[0];
-            assertFalse(result.containsKey(chunkIndex));
-            Object value = invocation.getArguments()[1];
-            result.put(chunkIndex, value);
-            signal.release();
-          }
-          return null;
-        };
-    doAnswer(answer).when(callback).onSuccess(anyInt(), anyObject());
-    doAnswer(answer).when(callback).onFailure(anyInt(), anyObject());
-
-    IOException ioe = new IOException("failed exception");
-    Map<Integer, List<Object>> interactions =
-        ImmutableMap.<Integer, List<Object>>builder()
-            .put(0, Arrays.asList(ioe, ioe, ioe, ioe, ioe, chunk0))
-            .put(1, Arrays.asList(ioe, ioe, ioe, ioe, ioe, chunk1))
-            .put(2, Arrays.asList(ioe, ioe, ioe, ioe, ioe, chunk2))
-            .build();
-
-    RetryingChunkClient client = performInteractions(interactions, callback);
-    while (!signal.tryAcquire(3, 500, TimeUnit.MILLISECONDS)) ;
-    // Note: this may exceeds the max retries we want, but it doesn't master.
-    assertEquals(4, client.getNumTries());
-    assertEquals(masterLocation, client.getCurrentReplica().getLocation());
-    assertEquals(ioe, result.get(0));
-    assertEquals(ioe, result.get(1));
-    assertEquals(ioe, result.get(2));
-  }
-
-  @Test
-  public void testRetryAndUnrecoverable() throws IOException, InterruptedException {
-    Map<Integer, Object> result = new ConcurrentHashMap<>();
-    ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
-    Semaphore signal = new Semaphore(3);
-    signal.acquire(3);
-
-    Answer<Void> answer =
-        invocation -> {
-          synchronized (signal) {
-            int chunkIndex = (Integer) invocation.getArguments()[0];
-            assertFalse(result.containsKey(chunkIndex));
-            Object value = invocation.getArguments()[1];
-            result.put(chunkIndex, value);
-            signal.release();
-          }
-          return null;
-        };
-    doAnswer(answer).when(callback).onSuccess(anyInt(), anyObject());
-    doAnswer(answer).when(callback).onFailure(anyInt(), anyObject());
-
-    Exception re = new RuntimeException("failed exception");
-    Map<Integer, List<Object>> interactions =
-        ImmutableMap.<Integer, List<Object>>builder()
-            .put(0, Arrays.asList(new IOException("first ioexception"), re, chunk0))
-            .put(1, Arrays.asList(chunk1))
-            .put(2, Arrays.asList(new IOException("second ioexception"), chunk2))
-            .build();
-
-    performInteractions(interactions, callback);
-    while (!signal.tryAcquire(3, 500, TimeUnit.MILLISECONDS)) ;
-    assertEquals(re, result.get(0));
-    assertEquals(chunk1, result.get(1));
-    assertEquals(chunk2, result.get(2));
-  }
-
-  private static RetryingChunkClient performInteractions(
-      Map<Integer, List<Object>> interactions, ChunkReceivedCallback callback)
-      throws IOException, InterruptedException {
-    CelebornConf conf = new CelebornConf();
-    conf.set("celeborn.data.io.maxRetries", "1");
-    conf.set("celeborn.data.io.retryWait", "0");
-
-    // Contains all chunk ids that are referenced across all interactions.
-    LinkedHashSet<Integer> chunkIds = Sets.newLinkedHashSet(interactions.keySet());
-
-    final TransportClient client = new DummyTransportClient(chunkIds.size(), interactions);
-    final TransportClientFactory clientFactory = mock(TransportClientFactory.class);
-    doAnswer(invocation -> client).when(clientFactory).createClient(anyString(), anyInt());
-
-    RetryingChunkClient retryingChunkClient =
-        new RetryingChunkClient(conf, "test", masterLocation, callback, clientFactory);
-    chunkIds.stream().sorted().forEach(retryingChunkClient::fetchChunk);
-    return retryingChunkClient;
-  }
-
-  private static class DummyTransportClient extends TransportClient {
-
-    private static final Channel channel = mock(Channel.class);
-    private static final TransportResponseHandler handler = mock(TransportResponseHandler.class);
-
-    private final long streamId = new Random().nextInt(Integer.MAX_VALUE) * 1000L;
-    private final int numChunks;
-    private final Map<Integer, List<Object>> interactions;
-    private final Map<Integer, Integer> chunkIdToInterActionIndex;
-
-    private final ScheduledExecutorService schedule =
-        ThreadUtils.newDaemonThreadPoolScheduledExecutor("test-fetch-chunk", 3);
-
-    DummyTransportClient(int numChunks, Map<Integer, List<Object>> interactions) {
-      super(channel, handler);
-      this.numChunks = numChunks;
-      this.interactions = interactions;
-      this.chunkIdToInterActionIndex = new ConcurrentHashMap<>();
-      interactions.keySet().forEach((chunkId) -> chunkIdToInterActionIndex.putIfAbsent(chunkId, 0));
-    }
-
-    @Override
-    public void fetchChunk(long streamId, int chunkId, ChunkReceivedCallback callback) {
-      schedule.schedule(
-          () -> {
-            Object action;
-            List<Object> interaction = interactions.get(chunkId);
-            synchronized (chunkIdToInterActionIndex) {
-              int index = chunkIdToInterActionIndex.get(chunkId);
-              assertTrue(index < interaction.size());
-              action = interaction.get(index);
-              chunkIdToInterActionIndex.put(chunkId, index + 1);
-            }
-
-            if (action instanceof ManagedBuffer) {
-              callback.onSuccess(chunkId, (ManagedBuffer) action);
-            } else if (action instanceof Exception) {
-              callback.onFailure(chunkId, (Exception) action);
-            } else {
-              fail("Can only handle ManagedBuffers and Exceptions, got " + action);
-            }
-          },
-          500,
-          TimeUnit.MILLISECONDS);
-    }
-
-    @Override
-    public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) {
-      StreamHandle handle = new StreamHandle(streamId, numChunks);
-      return handle.toByteBuffer();
-    }
-
-    @Override
-    public void close() {
-      super.close();
-      schedule.shutdownNow();
-    }
-  }
-}
diff --git a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
index e5e82db0..4ffc579d 100644
--- a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
+++ b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
@@ -21,14 +21,12 @@ import java.io.Closeable;
 import java.io.IOException;
 import java.net.SocketAddress;
 import java.nio.ByteBuffer;
-import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicLong;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Objects;
 import com.google.common.base.Preconditions;
-import com.google.common.base.Throwables;
 import com.google.common.util.concurrent.SettableFuture;
 import io.netty.channel.Channel;
 import io.netty.channel.ChannelFuture;
@@ -190,7 +188,7 @@ public class TransportClient implements Closeable {
    * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to a
    * specified timeout for a response.
    */
-  public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) {
+  public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) throws IOException {
     final SettableFuture<ByteBuffer> result = SettableFuture.create();
 
     sendRpc(
@@ -213,10 +211,8 @@ public class TransportClient implements Closeable {
 
     try {
       return result.get(timeoutMs, TimeUnit.MILLISECONDS);
-    } catch (ExecutionException e) {
-      throw Throwables.propagate(e.getCause());
     } catch (Exception e) {
-      throw Throwables.propagate(e);
+      throw new IOException("Exception in sendRpcSync", e);
     }
   }
 
diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 6767cd3d..16fcdd82 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -642,6 +642,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
   // //////////////////////////////////////////////////////
   def fetchTimeoutMs: Long = get(FETCH_TIMEOUT)
   def fetchMaxReqsInFlight: Int = get(FETCH_MAX_REQS_IN_FLIGHT)
+  def fetchMaxRetries: Int = get(FETCH_MAX_RETRIES)
+  def testFetchFailure: Boolean = get(TEST_FETCH_FAILURE)
 
   // //////////////////////////////////////////////////////
   //               Shuffle Client Push                   //
@@ -1256,6 +1258,22 @@ object CelebornConf extends Logging {
       .intConf
       .createWithDefault(1024)
 
+  val FETCH_MAX_RETRIES: ConfigEntry[Int] =
+    buildConf("celeborn.fetch.maxRetries")
+      .categories("client")
+      .version("0.2.0")
+      .doc("Max retries of fetch chunk")
+      .intConf
+      .createWithDefault(3)
+
+  val TEST_FETCH_FAILURE: ConfigEntry[Boolean] =
+    buildConf("celeborn.test.fetchFailure")
+      .categories("client")
+      .version("0.2.0")
+      .doc("Whether to test fetch chunk failure")
+      .booleanConf
+      .createWithDefault(false)
+
   val APPLICATION_HEARTBEAT_TIMEOUT: ConfigEntry[Long] =
     buildConf("celeborn.application.heartbeat.timeout")
       .withAlternative("rss.application.timeout")
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index 9a5127fd..111452cb 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -23,6 +23,7 @@ license: |
 | celeborn.client.maxRetries | 15 | Max retry times for client to connect master endpoint | 0.2.0 | 
 | celeborn.client.rpc.askTimeout | &lt;value of celeborn.network.timeout&gt; | Timeout for client RPC ask operations. | 0.2.0 | 
 | celeborn.fetch.maxReqsInFlight | 3 | Amount of in-flight chunk fetch request. | 0.2.0 | 
+| celeborn.fetch.maxRetries | 3 | Max retries of fetch chunk | 0.2.0 | 
 | celeborn.fetch.timeout | 120s | Timeout for a task to fetch chunk. | 0.2.0 | 
 | celeborn.master.endpoints | &lt;localhost&gt;:9097 | Endpoints of master nodes for celeborn client to connect, allowed pattern is: `<host1>:<port1>[,<host2>:<port2>]*`, e.g. `clb1:9097,clb2:9098,clb3:9099`. If the port is omitted, 9097 will be used. | 0.2.0 | 
 | celeborn.push.buffer.initial.size | 8k |  | 0.2.0 | 
@@ -60,6 +61,7 @@ license: |
 | celeborn.slots.reserve.maxRetries | 3 | Max retry times for client to reserve slots. | 0.2.0 | 
 | celeborn.slots.reserve.retryWait | 3s | Wait time before next retry if reserve slots failed. | 0.2.0 | 
 | celeborn.storage.hdfs.dir | &lt;undefined&gt; | HDFS dir configuration for Celeborn to access HDFS. | 0.2.0 | 
+| celeborn.test.fetchFailure | false | Whether to test fetch chunk failure | 0.2.0 | 
 | celeborn.worker.excluded.checkInterval | 30s | Interval for client to refresh excluded worker list. | 0.2.0 | 
 | celeborn.worker.excluded.expireTimeout | 600s | Timeout time for LifecycleManager to clear reserved excluded worker. | 0.2.0 | 
 <!--end-include-->
diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/FileWriterSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/FileWriterSuiteJ.java
index 9f7fa994..935519c9 100644
--- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/FileWriterSuiteJ.java
+++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/FileWriterSuiteJ.java
@@ -190,7 +190,7 @@ public class FileWriterSuiteJ {
     return openBlocks.toByteBuffer();
   }
 
-  private void setUpConn(TransportClient client) {
+  private void setUpConn(TransportClient client) throws IOException {
     ByteBuffer resp = client.sendRpcSync(createOpenMessage(), 10000);
     StreamHandle streamHandle = (StreamHandle) Message.decode(resp);
     streamId = streamHandle.streamId;