You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@uniffle.apache.org by ro...@apache.org on 2022/08/02 02:50:59 UTC

[incubator-uniffle] branch master updated: [ISSUE-107][IMPROVEMENT] Assign partition again if registerShuffleServers failed (#115)

This is an automated email from the ASF dual-hosted git repository.

roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new d07b659  [ISSUE-107][IMPROVEMENT] Assign partition again if registerShuffleServers failed (#115)
d07b659 is described below

commit d07b65908af6676d6a7583a01c3098217abe798c
Author: xianjingfeng <58...@qq.com>
AuthorDate: Tue Aug 2 10:50:54 2022 +0800

    [ISSUE-107][IMPROVEMENT] Assign partition again if registerShuffleServers failed (#115)
    
    ### What changes were proposed in this pull request?
    Solve issue #107: assign partitions again if registerShuffleServers fails.
    
    ### Why are the changes needed?
    If registerShuffleServers fails, the task will fail and then the application will fail.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Already added
---
 .../org/apache/hadoop/mapreduce/RssMRConfig.java   |   9 ++
 .../hadoop/mapreduce/v2/app/RssMRAppMaster.java    | 127 ++++++++++++---------
 .../org/apache/spark/shuffle/RssSparkConfig.java   |   8 ++
 .../apache/spark/shuffle/RssShuffleManager.java    |  22 +++-
 .../apache/spark/shuffle/RssShuffleManager.java    |  31 +++--
 .../uniffle/client/util/RssClientConfig.java       |   5 +
 .../org/apache/uniffle/common/util/RetryUtils.java |  81 +++++++++++++
 .../apache/uniffle/common/util/RetryUtilsTest.java |  67 +++++++++++
 .../uniffle/test/ShuffleWithRssClientTest.java     |  46 ++++++++
 9 files changed, 328 insertions(+), 68 deletions(-)

diff --git a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java
index ef47e21..1c98bda 100644
--- a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java
+++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java
@@ -151,6 +151,15 @@ public class RssMRConfig {
   public static final int RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE =
       RssClientConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE;
 
+  public static final String RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL =
+          MR_RSS_CONFIG_PREFIX + RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL;
+  public static final long RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL_DEFAULT_VALUE =
+          RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL_DEFAULT_VALUE;
+  public static final String RSS_CLIENT_ASSIGNMENT_RETRY_TIMES =
+          MR_RSS_CONFIG_PREFIX + RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES;
+  public static final int RSS_CLIENT_ASSIGNMENT_RETRY_TIMES_DEFAULT_VALUE =
+          RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES_DEFAULT_VALUE;
+  
   public static final String RSS_CONF_FILE = "rss_conf.xml";
 
   public static final Set<String> RSS_MANDATORY_CLUSTER_CONF = Sets.newHashSet(
diff --git a/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java b/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
index c65f2a2..a3cb270 100644
--- a/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
+++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
@@ -77,7 +77,9 @@ import org.apache.uniffle.common.PartitionRange;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleAssignmentsInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.util.Constants;
+import org.apache.uniffle.common.util.RetryUtils;
 import org.apache.uniffle.storage.util.StorageType;
 
 public class RssMRAppMaster extends MRAppMaster {
@@ -128,25 +130,9 @@ public class RssMRAppMaster extends MRAppMaster {
       }
       assignmentTags.add(Constants.SHUFFLE_SERVER_VERSION);
 
-      int requiredAssignmentShuffleServersNum = conf.getInt(
-          RssMRConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER,
-          RssMRConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE
-      );
-
       ApplicationAttemptId applicationAttemptId = RssMRUtils.getApplicationAttemptId();
       String appId = applicationAttemptId.toString();
 
-      ShuffleAssignmentsInfo response =
-          client.getShuffleAssignments(
-              appId,
-              0,
-              numReduceTasks,
-              1,
-              Sets.newHashSet(assignmentTags),
-              requiredAssignmentShuffleServersNum
-          );
-
-      Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges = response.getServerToPartitionRanges();
       final ScheduledExecutorService scheduledExecutorService = Executors.newSingleThreadScheduledExecutor(
           new ThreadFactory() {
             @Override
@@ -157,40 +143,9 @@ public class RssMRAppMaster extends MRAppMaster {
             }
           }
       );
-      if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) {
-        return;
-      }
-
-      long heartbeatInterval = conf.getLong(RssMRConfig.RSS_HEARTBEAT_INTERVAL,
-          RssMRConfig.RSS_HEARTBEAT_INTERVAL_DEFAULT_VALUE);
-      long heartbeatTimeout = conf.getLong(RssMRConfig.RSS_HEARTBEAT_TIMEOUT, heartbeatInterval / 2);
-      scheduledExecutorService.scheduleAtFixedRate(
-          () -> {
-            try {
-              client.sendAppHeartbeat(appId, heartbeatTimeout);
-              LOG.info("Finish send heartbeat to coordinator and servers");
-            } catch (Exception e) {
-              LOG.warn("Fail to send heartbeat to coordinator and servers", e);
-            }
-          },
-          heartbeatInterval / 2,
-          heartbeatInterval,
-          TimeUnit.MILLISECONDS);
 
       JobConf extraConf = new JobConf();
       extraConf.clear();
-      // write shuffle worker assignments to submit work directory
-      // format is as below:
-      // mapreduce.rss.assignment.partition.1:server1,server2
-      // mapreduce.rss.assignment.partition.2:server3,server4
-      // ...
-      response.getPartitionToServers().entrySet().forEach(entry -> {
-        List<String> servers = Lists.newArrayList();
-        for (ShuffleServerInfo server : entry.getValue()) {
-          servers.add(server.getHost() + ":" + server.getPort());
-        }
-        extraConf.set(RssMRConfig.RSS_ASSIGNMENT_PREFIX + entry.getKey(), StringUtils.join(servers, ","));
-      });
 
       // get remote storage from coordinator if necessary
       boolean dynamicConfEnabled = conf.getBoolean(RssMRConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED,
@@ -233,14 +188,80 @@ public class RssMRAppMaster extends MRAppMaster {
         }
         conf.setInt(MRJobConfig.REDUCE_MAX_ATTEMPTS, originalAttempts + inc);
       }
+      
+      int requiredAssignmentShuffleServersNum = conf.getInt(
+              RssMRConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER,
+              RssMRConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE
+      );
+      
+      // retryInterval must be bigger than `rss.server.heartbeat.timeout`, or the retry may return the same result
+      long retryInterval = conf.getLong(RssMRConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL,
+              RssMRConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL_DEFAULT_VALUE);
+      int retryTimes = conf.getInt(RssMRConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES,
+              RssMRConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES_DEFAULT_VALUE);
+      ShuffleAssignmentsInfo response;
+      try {
+        response = RetryUtils.retry(() -> {
+          ShuffleAssignmentsInfo shuffleAssignments =
+                  client.getShuffleAssignments(
+                          appId,
+                          0,
+                          numReduceTasks,
+                          1,
+                          Sets.newHashSet(assignmentTags),
+                          requiredAssignmentShuffleServersNum
+                  );
+
+          Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges =
+              shuffleAssignments.getServerToPartitionRanges();
+
+          if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) {
+            return null;
+          }
+          LOG.info("Start to register shuffle");
+          long start = System.currentTimeMillis();
+          serverToPartitionRanges.entrySet().forEach(entry -> {
+            client.registerShuffle(
+                entry.getKey(), appId, 0, entry.getValue(), remoteStorage);
+          });
+          LOG.info("Finish register shuffle with " + (System.currentTimeMillis() - start) + " ms");
+          return shuffleAssignments;
+        }, retryInterval, retryTimes);
+      } catch (Throwable throwable) {
+        throw new RssException("registerShuffle failed!", throwable);
+      }
 
-      LOG.info("Start to register shuffle");
-      long start = System.currentTimeMillis();
-      serverToPartitionRanges.entrySet().forEach(entry -> {
-        client.registerShuffle(
-            entry.getKey(), appId, 0, entry.getValue(), remoteStorage);
+      if (response == null) {
+        return;
+      }
+      long heartbeatInterval = conf.getLong(RssMRConfig.RSS_HEARTBEAT_INTERVAL,
+          RssMRConfig.RSS_HEARTBEAT_INTERVAL_DEFAULT_VALUE);
+      long heartbeatTimeout = conf.getLong(RssMRConfig.RSS_HEARTBEAT_TIMEOUT, heartbeatInterval / 2);
+      scheduledExecutorService.scheduleAtFixedRate(
+          () -> {
+            try {
+              client.sendAppHeartbeat(appId, heartbeatTimeout);
+              LOG.info("Finish send heartbeat to coordinator and servers");
+            } catch (Exception e) {
+              LOG.warn("Fail to send heartbeat to coordinator and servers", e);
+            }
+          },
+          heartbeatInterval / 2,
+          heartbeatInterval,
+          TimeUnit.MILLISECONDS);
+
+      // write shuffle worker assignments to submit work directory
+      // format is as below:
+      // mapreduce.rss.assignment.partition.1:server1,server2
+      // mapreduce.rss.assignment.partition.2:server3,server4
+      // ...
+      response.getPartitionToServers().entrySet().forEach(entry -> {
+        List<String> servers = Lists.newArrayList();
+        for (ShuffleServerInfo server : entry.getValue()) {
+          servers.add(server.getHost() + ":" + server.getPort());
+        }
+        extraConf.set(RssMRConfig.RSS_ASSIGNMENT_PREFIX + entry.getKey(), StringUtils.join(servers, ","));
       });
-      LOG.info("Finish register shuffle with " + (System.currentTimeMillis() - start) + " ms");
 
       writeExtraConf(conf, extraConf);
 
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 6b549b1..c546bdc 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -207,6 +207,14 @@ public class RssSparkConfig {
       new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER))
       .createWithDefault(RssClientConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE);
 
+  public static final ConfigEntry<Long> RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL = createLongBuilder(
+          new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL))
+          .createWithDefault(RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL_DEFAULT_VALUE);
+
+  public static final ConfigEntry<Integer> RSS_CLIENT_ASSIGNMENT_RETRY_TIMES = createIntegerBuilder(
+          new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES))
+          .createWithDefault(RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES_DEFAULT_VALUE);
+  
   public static final ConfigEntry<String> RSS_COORDINATOR_QUORUM = createStringBuilder(
       new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_COORDINATOR_QUORUM)
           .doc("Coordinator quorum"))
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index ec84308..753d759 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -60,6 +60,8 @@ import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleAssignmentsInfo;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.RetryUtils;
 import org.apache.uniffle.common.util.RssUtils;
 import org.apache.uniffle.common.util.ThreadUtils;
 
@@ -220,13 +222,23 @@ public class RssShuffleManager implements ShuffleManager {
 
     int requiredShuffleServerNumber = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER);
 
-    ShuffleAssignmentsInfo response = shuffleWriteClient.getShuffleAssignments(
-        appId, shuffleId, dependency.partitioner().numPartitions(),
-        partitionNumPerRange, assignmentTags, requiredShuffleServerNumber);
-    Map<Integer, List<ShuffleServerInfo>> partitionToServers = response.getPartitionToServers();
+    // retryInterval must be bigger than `rss.server.heartbeat.timeout`, or the retry may return the same result
+    long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
+    int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers;
+    try {
+      partitionToServers = RetryUtils.retry(() -> {
+        ShuffleAssignmentsInfo response = shuffleWriteClient.getShuffleAssignments(
+                appId, shuffleId, dependency.partitioner().numPartitions(),
+                partitionNumPerRange, assignmentTags, requiredShuffleServerNumber);
+        registerShuffleServers(appId, shuffleId, response.getServerToPartitionRanges());
+        return response.getPartitionToServers();
+      }, retryInterval, retryTimes);
+    } catch (Throwable throwable) {
+      throw new RssException("registerShuffle failed!", throwable);
+    }
 
     startHeartbeat();
-    registerShuffleServers(appId, shuffleId, response.getServerToPartitionRanges());
 
     LOG.info("RegisterShuffle with ShuffleId[" + shuffleId + "], partitionNum[" + partitionToServers.size() + "]");
     return new RssShuffleHandle(shuffleId, appId, numMaps, dependency, partitionToServers, remoteStorage);
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 80fac99..030e56f 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -64,6 +64,8 @@ import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleAssignmentsInfo;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.RetryUtils;
 import org.apache.uniffle.common.util.RssUtils;
 import org.apache.uniffle.common.util.ThreadUtils;
 
@@ -258,17 +260,26 @@ public class RssShuffleManager implements ShuffleManager {
 
     int requiredShuffleServerNumber = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER);
 
-    ShuffleAssignmentsInfo response = shuffleWriteClient.getShuffleAssignments(
-        id.get(),
-        shuffleId,
-        dependency.partitioner().numPartitions(),
-        1,
-        assignmentTags,
-        requiredShuffleServerNumber);
-    Map<Integer, List<ShuffleServerInfo>> partitionToServers = response.getPartitionToServers();
-
+    // retryInterval must be bigger than `rss.server.heartbeat.timeout`, or the retry may return the same result
+    long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
+    int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers;
+    try {
+      partitionToServers = RetryUtils.retry(() -> {
+        ShuffleAssignmentsInfo response = shuffleWriteClient.getShuffleAssignments(
+                id.get(),
+                shuffleId,
+                dependency.partitioner().numPartitions(),
+                1,
+                assignmentTags,
+                requiredShuffleServerNumber);
+        registerShuffleServers(id.get(), shuffleId, response.getServerToPartitionRanges());
+        return response.getPartitionToServers();
+      }, retryInterval, retryTimes);
+    } catch (Throwable throwable) {
+      throw new RssException("registerShuffle failed!", throwable);
+    }
     startHeartbeat();
-    registerShuffleServers(id.get(), shuffleId, response.getServerToPartitionRanges());
 
     LOG.info("RegisterShuffle with ShuffleId[" + shuffleId + "], partitionNum[" + partitionToServers.size()
         + "], shuffleServerForResult: " + partitionToServers);
diff --git a/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java b/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java
index eb6006a..b3247b4 100644
--- a/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java
+++ b/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java
@@ -60,6 +60,11 @@ public class RssClientConfig {
   public static final String RSS_CLIENT_READ_BUFFER_SIZE_DEFAULT_VALUE = "14m";
   // The tags specified by rss client to determine server assignment.
   public static final String RSS_CLIENT_ASSIGNMENT_TAGS = "rss.client.assignment.tags";
+  
+  public static final String RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL = "rss.client.assignment.retry.interval";
+  public static final long RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL_DEFAULT_VALUE = 65000;
+  public static final String RSS_CLIENT_ASSIGNMENT_RETRY_TIMES = "rss.client.assignment.retry.times";
+  public static final int RSS_CLIENT_ASSIGNMENT_RETRY_TIMES_DEFAULT_VALUE = 3;
 
   public static final String RSS_ACCESS_TIMEOUT_MS = "rss.access.timeout.ms";
   public static final int RSS_ACCESS_TIMEOUT_MS_DEFAULT_VALUE = 10000;
diff --git a/common/src/main/java/org/apache/uniffle/common/util/RetryUtils.java b/common/src/main/java/org/apache/uniffle/common/util/RetryUtils.java
new file mode 100644
index 0000000..603873f
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/util/RetryUtils.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.util;
+
+import java.util.Set;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class RetryUtils {
+  private static final Logger LOG = LoggerFactory.getLogger(RetryUtils.class);
+
+  public static <T> T retry(RetryCmd<T> cmd, long intervalMs, int retryTimes) throws Throwable {
+    return retry(cmd, null, intervalMs, retryTimes, null); // no callback; null filter = retry on any Throwable
+  }
+
+  /** Runs {@code cmd} until it succeeds, hits a non-retryable failure, or runs out of attempts.
+   * @param cmd              command to execute
+   * @param callBack         run after each sleep, just before the next attempt; may be null
+   * @param intervalMs       sleep between attempts, in milliseconds
+   * @param retryTimes       maximum number of attempts, first one included; below 1 behaves like 1
+   * @param exceptionClasses throwable types worth retrying (isInstance match); null retries all
+   * @param <T>              return type of the command
+   * @return the result of the first attempt that completes without throwing
+   * @throws Throwable the failure of the last attempt, or the first non-retryable failure
+   */
+  public static <T> T retry(RetryCmd<T> cmd, RetryCallBack callBack, long intervalMs,
+                            int retryTimes, Set<Class> exceptionClasses) throws Throwable {
+    int retry = 0; // number of failed attempts so far
+    while (true) {
+      try {
+        T ret = cmd.execute();
+        return ret;
+      } catch (Throwable t) {
+        retry++;
+        if ((exceptionClasses != null && !isInstanceOf(exceptionClasses, t)) || retry >= retryTimes) {
+          throw t; // filtered out, or attempts exhausted: rethrow the original failure unchanged
+        } else {
+          LOG.info("Retry due to Throwable, " + t.getClass().getName() + " " + t.getMessage());
+          LOG.info("Waiting " + intervalMs + " milliseconds before next connection attempt.");
+          Thread.sleep(intervalMs); // NOTE(review): an InterruptedException here escapes instead of t
+          if (callBack != null) {
+            callBack.execute(); // a throwing callback also ends the loop with that exception
+          }
+        }
+      }
+    }
+  }
+
+  private static boolean isInstanceOf(Set<Class> classes, Throwable t) { // true if t matches any class
+    for (Class c : classes) {
+      if (c.isInstance(t)) { // subclasses of a listed type match as well
+        return true;
+      }
+    }
+    return false;
+  }
+
+  public interface RetryCmd<T> { // the unit of work that gets retried
+    T execute() throws Throwable;
+  }
+
+  public interface RetryCallBack { // hook invoked between two attempts
+    void execute() throws Throwable;
+  }
+}
diff --git a/common/src/test/java/org/apache/uniffle/common/util/RetryUtilsTest.java b/common/src/test/java/org/apache/uniffle/common/util/RetryUtilsTest.java
new file mode 100644
index 0000000..d119d27
--- /dev/null
+++ b/common/src/test/java/org/apache/uniffle/common/util/RetryUtilsTest.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.util;
+
+import com.google.common.collect.Sets;
+import org.apache.uniffle.common.exception.RssException;
+import org.junit.jupiter.api.Test;
+
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class RetryUtilsTest {
+  @Test
+  public void testRetry() {
+    AtomicInteger tryTimes = new AtomicInteger();
+    AtomicInteger callbackTime = new AtomicInteger();
+    int maxTryTime = 3;
+    try { // case 1: always fails with an exception type that is in the retry filter
+      RetryUtils.retry(() -> {
+        tryTimes.incrementAndGet();
+        throw new RssException("");
+      }, () -> {
+        callbackTime.incrementAndGet();
+      }, 10, maxTryTime, Sets.newHashSet(RssException.class));
+    } catch (Throwable throwable) { // expected: the last RssException is rethrown once attempts run out
+    }
+    assertEquals(tryTimes.get(), maxTryTime); // every allowed attempt was made
+    assertEquals(callbackTime.get(), maxTryTime - 1); // callback runs between attempts, not after the last
+
+    tryTimes.set(0);
+    try { // case 2: no exception filter, so a plain Exception is retried as well
+      RetryUtils.retry(() -> {
+        tryTimes.incrementAndGet();
+        throw new Exception("");
+      }, 10, maxTryTime);
+    } catch (Throwable throwable) { // expected: rethrown after the final attempt
+    }
+    assertEquals(tryTimes.get(), maxTryTime);
+
+    tryTimes.set(0);
+    try { // case 3: the command succeeds immediately
+      int ret = RetryUtils.retry(() -> {
+        tryTimes.incrementAndGet();
+        return 1;
+      }, 10, maxTryTime);
+      assertEquals(ret, 1);
+    } catch (Throwable throwable) { // NOTE(review): also swallows an AssertionError from assertEquals(ret, 1)
+    }
+    assertEquals(tryTimes.get(), 1); // success on the first attempt means no retry happened
+  }
+}
diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
index f2e35c1..ec7d227 100644
--- a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
+++ b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
@@ -20,12 +20,15 @@ package org.apache.uniffle.test;
 import java.io.File;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
 import com.google.common.io.Files;
 import org.apache.uniffle.client.util.DefaultIdHelper;
+import org.apache.uniffle.common.util.Constants;
+import org.apache.uniffle.common.util.RetryUtils;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.BeforeAll;
@@ -39,6 +42,7 @@ import org.apache.uniffle.client.util.ClientType;
 import org.apache.uniffle.client.util.ClientUtils;
 import org.apache.uniffle.common.PartitionRange;
 import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.ShuffleAssignmentsInfo;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.coordinator.CoordinatorConf;
@@ -47,6 +51,7 @@ import org.apache.uniffle.storage.util.StorageType;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.junit.jupiter.api.Assertions.fail;
 
@@ -273,4 +278,45 @@ public class ShuffleWithRssClientTest extends ShuffleReadWriteBase {
         .sendCommit(Sets.newHashSet(shuffleServerInfo2), testAppId, 0, 2);
     assertFalse(commitResult);
   }
+
+  @Test
+  public void testRetryAssgin() throws Throwable {
+    int maxTryTime = shuffleServers.size(); // as many attempts as there are entries in shuffleServers
+    AtomicInteger tryTime = new AtomicInteger();
+    String appId = "app-1";
+    RemoteStorageInfo remoteStorage = new RemoteStorageInfo("");
+    ShuffleAssignmentsInfo response = null;
+    ShuffleServerConf shuffleServerConf = getShuffleServerConf();
+    int heartbeatTimeout = shuffleServerConf.getInteger("rss.server.heartbeat.timeout", 65000); // retry interval
+    int heartbeatInterval = shuffleServerConf.getInteger("rss.server.heartbeat.interval", 1000);
+    Thread.sleep(heartbeatInterval * 2); // presumably lets the servers report to the coordinator - TODO confirm
+    shuffleWriteClientImpl.registerCoordinators(COORDINATOR_QUORUM);
+    response = RetryUtils.retry(() -> {
+      int currentTryTime = tryTime.incrementAndGet();
+      ShuffleAssignmentsInfo shuffleAssignments = shuffleWriteClientImpl.getShuffleAssignments(appId,
+          1, 1, 1, Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION), 1);
+
+      Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges =
+          shuffleAssignments.getServerToPartitionRanges();
+
+      serverToPartitionRanges.entrySet().forEach(entry -> {
+        if (currentTryTime < maxTryTime) { // on every attempt but the last, stop the assigned server first
+          shuffleServers.forEach((ss) -> {
+            if (ss.getId().equals(entry.getKey().getId())) {
+              try {
+                ss.stopServer();
+              } catch (Exception e) {
+                e.printStackTrace();
+              }
+            }
+          });
+        }
+        shuffleWriteClientImpl.registerShuffle( // expected to throw for a stopped server - TODO confirm
+            entry.getKey(), appId, 0, entry.getValue(), remoteStorage);
+      });
+      return shuffleAssignments;
+    }, heartbeatTimeout, maxTryTime); // sleeps one heartbeat timeout between attempts
+
+    assertNotNull(response); // the final attempt stops no server, so retry is expected to return a result
+  }
 }