You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@uniffle.apache.org by ro...@apache.org on 2022/08/02 02:50:59 UTC
[incubator-uniffle] branch master updated: [ISSUE-107][IMPROVEMENT] Assign partition again if registerShuffleServers failed (#115)
This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new d07b659 [ISSUE-107][IMPROVEMENT] Assign partition again if registerShuffleServers failed (#115)
d07b659 is described below
commit d07b65908af6676d6a7583a01c3098217abe798c
Author: xianjingfeng <58...@qq.com>
AuthorDate: Tue Aug 2 10:50:54 2022 +0800
[ISSUE-107][IMPROVEMENT] Assign partition again if registerShuffleServers failed (#115)
### What changes were proposed in this pull request?
Solves issue #107: assign partitions again if registerShuffleServers fails.
### Why are the changes needed?
If registerShuffleServers fails, the task fails and then the whole application fails.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added a unit test (RetryUtilsTest) and an integration test (ShuffleWithRssClientTest#testRetryAssgin).
---
.../org/apache/hadoop/mapreduce/RssMRConfig.java | 9 ++
.../hadoop/mapreduce/v2/app/RssMRAppMaster.java | 127 ++++++++++++---------
.../org/apache/spark/shuffle/RssSparkConfig.java | 8 ++
.../apache/spark/shuffle/RssShuffleManager.java | 22 +++-
.../apache/spark/shuffle/RssShuffleManager.java | 31 +++--
.../uniffle/client/util/RssClientConfig.java | 5 +
.../org/apache/uniffle/common/util/RetryUtils.java | 81 +++++++++++++
.../apache/uniffle/common/util/RetryUtilsTest.java | 67 +++++++++++
.../uniffle/test/ShuffleWithRssClientTest.java | 46 ++++++++
9 files changed, 328 insertions(+), 68 deletions(-)
diff --git a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java
index ef47e21..1c98bda 100644
--- a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java
+++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java
@@ -151,6 +151,15 @@ public class RssMRConfig {
public static final int RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE =
RssClientConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE;
+ public static final String RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL =
+ MR_RSS_CONFIG_PREFIX + RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL;
+ public static final long RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL_DEFAULT_VALUE =
+ RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL_DEFAULT_VALUE;
+ public static final String RSS_CLIENT_ASSIGNMENT_RETRY_TIMES =
+ MR_RSS_CONFIG_PREFIX + RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES;
+ public static final int RSS_CLIENT_ASSIGNMENT_RETRY_TIMES_DEFAULT_VALUE =
+ RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES_DEFAULT_VALUE;
+
public static final String RSS_CONF_FILE = "rss_conf.xml";
public static final Set<String> RSS_MANDATORY_CLUSTER_CONF = Sets.newHashSet(
diff --git a/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java b/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
index c65f2a2..a3cb270 100644
--- a/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
+++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
@@ -77,7 +77,9 @@ import org.apache.uniffle.common.PartitionRange;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleAssignmentsInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.Constants;
+import org.apache.uniffle.common.util.RetryUtils;
import org.apache.uniffle.storage.util.StorageType;
public class RssMRAppMaster extends MRAppMaster {
@@ -128,25 +130,9 @@ public class RssMRAppMaster extends MRAppMaster {
}
assignmentTags.add(Constants.SHUFFLE_SERVER_VERSION);
- int requiredAssignmentShuffleServersNum = conf.getInt(
- RssMRConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER,
- RssMRConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE
- );
-
ApplicationAttemptId applicationAttemptId = RssMRUtils.getApplicationAttemptId();
String appId = applicationAttemptId.toString();
- ShuffleAssignmentsInfo response =
- client.getShuffleAssignments(
- appId,
- 0,
- numReduceTasks,
- 1,
- Sets.newHashSet(assignmentTags),
- requiredAssignmentShuffleServersNum
- );
-
- Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges = response.getServerToPartitionRanges();
final ScheduledExecutorService scheduledExecutorService = Executors.newSingleThreadScheduledExecutor(
new ThreadFactory() {
@Override
@@ -157,40 +143,9 @@ public class RssMRAppMaster extends MRAppMaster {
}
}
);
- if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) {
- return;
- }
-
- long heartbeatInterval = conf.getLong(RssMRConfig.RSS_HEARTBEAT_INTERVAL,
- RssMRConfig.RSS_HEARTBEAT_INTERVAL_DEFAULT_VALUE);
- long heartbeatTimeout = conf.getLong(RssMRConfig.RSS_HEARTBEAT_TIMEOUT, heartbeatInterval / 2);
- scheduledExecutorService.scheduleAtFixedRate(
- () -> {
- try {
- client.sendAppHeartbeat(appId, heartbeatTimeout);
- LOG.info("Finish send heartbeat to coordinator and servers");
- } catch (Exception e) {
- LOG.warn("Fail to send heartbeat to coordinator and servers", e);
- }
- },
- heartbeatInterval / 2,
- heartbeatInterval,
- TimeUnit.MILLISECONDS);
JobConf extraConf = new JobConf();
extraConf.clear();
- // write shuffle worker assignments to submit work directory
- // format is as below:
- // mapreduce.rss.assignment.partition.1:server1,server2
- // mapreduce.rss.assignment.partition.2:server3,server4
- // ...
- response.getPartitionToServers().entrySet().forEach(entry -> {
- List<String> servers = Lists.newArrayList();
- for (ShuffleServerInfo server : entry.getValue()) {
- servers.add(server.getHost() + ":" + server.getPort());
- }
- extraConf.set(RssMRConfig.RSS_ASSIGNMENT_PREFIX + entry.getKey(), StringUtils.join(servers, ","));
- });
// get remote storage from coordinator if necessary
boolean dynamicConfEnabled = conf.getBoolean(RssMRConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED,
@@ -233,14 +188,80 @@ public class RssMRAppMaster extends MRAppMaster {
}
conf.setInt(MRJobConfig.REDUCE_MAX_ATTEMPTS, originalAttempts + inc);
}
+
+ int requiredAssignmentShuffleServersNum = conf.getInt(
+ RssMRConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER,
+ RssMRConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE
+ );
+
+ // retryInterval must bigger than `rss.server.heartbeat.timeout`, or maybe it will return the same result
+ long retryInterval = conf.getLong(RssMRConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL,
+ RssMRConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL_DEFAULT_VALUE);
+ int retryTimes = conf.getInt(RssMRConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES,
+ RssMRConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES_DEFAULT_VALUE);
+ ShuffleAssignmentsInfo response;
+ try {
+ response = RetryUtils.retry(() -> {
+ ShuffleAssignmentsInfo shuffleAssignments =
+ client.getShuffleAssignments(
+ appId,
+ 0,
+ numReduceTasks,
+ 1,
+ Sets.newHashSet(assignmentTags),
+ requiredAssignmentShuffleServersNum
+ );
+
+ Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges =
+ shuffleAssignments.getServerToPartitionRanges();
+
+ if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) {
+ return null;
+ }
+ LOG.info("Start to register shuffle");
+ long start = System.currentTimeMillis();
+ serverToPartitionRanges.entrySet().forEach(entry -> {
+ client.registerShuffle(
+ entry.getKey(), appId, 0, entry.getValue(), remoteStorage);
+ });
+ LOG.info("Finish register shuffle with " + (System.currentTimeMillis() - start) + " ms");
+ return shuffleAssignments;
+ }, retryInterval, retryTimes);
+ } catch (Throwable throwable) {
+ throw new RssException("registerShuffle failed!", throwable);
+ }
- LOG.info("Start to register shuffle");
- long start = System.currentTimeMillis();
- serverToPartitionRanges.entrySet().forEach(entry -> {
- client.registerShuffle(
- entry.getKey(), appId, 0, entry.getValue(), remoteStorage);
+ if (response == null) {
+ return;
+ }
+ long heartbeatInterval = conf.getLong(RssMRConfig.RSS_HEARTBEAT_INTERVAL,
+ RssMRConfig.RSS_HEARTBEAT_INTERVAL_DEFAULT_VALUE);
+ long heartbeatTimeout = conf.getLong(RssMRConfig.RSS_HEARTBEAT_TIMEOUT, heartbeatInterval / 2);
+ scheduledExecutorService.scheduleAtFixedRate(
+ () -> {
+ try {
+ client.sendAppHeartbeat(appId, heartbeatTimeout);
+ LOG.info("Finish send heartbeat to coordinator and servers");
+ } catch (Exception e) {
+ LOG.warn("Fail to send heartbeat to coordinator and servers", e);
+ }
+ },
+ heartbeatInterval / 2,
+ heartbeatInterval,
+ TimeUnit.MILLISECONDS);
+
+ // write shuffle worker assignments to submit work directory
+ // format is as below:
+ // mapreduce.rss.assignment.partition.1:server1,server2
+ // mapreduce.rss.assignment.partition.2:server3,server4
+ // ...
+ response.getPartitionToServers().entrySet().forEach(entry -> {
+ List<String> servers = Lists.newArrayList();
+ for (ShuffleServerInfo server : entry.getValue()) {
+ servers.add(server.getHost() + ":" + server.getPort());
+ }
+ extraConf.set(RssMRConfig.RSS_ASSIGNMENT_PREFIX + entry.getKey(), StringUtils.join(servers, ","));
});
- LOG.info("Finish register shuffle with " + (System.currentTimeMillis() - start) + " ms");
writeExtraConf(conf, extraConf);
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 6b549b1..c546bdc 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -207,6 +207,14 @@ public class RssSparkConfig {
new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER))
.createWithDefault(RssClientConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE);
+ public static final ConfigEntry<Long> RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL = createLongBuilder(
+ new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL))
+ .createWithDefault(RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL_DEFAULT_VALUE);
+
+ public static final ConfigEntry<Integer> RSS_CLIENT_ASSIGNMENT_RETRY_TIMES = createIntegerBuilder(
+ new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES))
+ .createWithDefault(RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES_DEFAULT_VALUE);
+
public static final ConfigEntry<String> RSS_COORDINATOR_QUORUM = createStringBuilder(
new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_COORDINATOR_QUORUM)
.doc("Coordinator quorum"))
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index ec84308..753d759 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -60,6 +60,8 @@ import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleAssignmentsInfo;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.RetryUtils;
import org.apache.uniffle.common.util.RssUtils;
import org.apache.uniffle.common.util.ThreadUtils;
@@ -220,13 +222,23 @@ public class RssShuffleManager implements ShuffleManager {
int requiredShuffleServerNumber = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER);
- ShuffleAssignmentsInfo response = shuffleWriteClient.getShuffleAssignments(
- appId, shuffleId, dependency.partitioner().numPartitions(),
- partitionNumPerRange, assignmentTags, requiredShuffleServerNumber);
- Map<Integer, List<ShuffleServerInfo>> partitionToServers = response.getPartitionToServers();
+ // retryInterval must bigger than `rss.server.heartbeat.timeout`, or maybe it will return the same result
+ long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
+ int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
+ Map<Integer, List<ShuffleServerInfo>> partitionToServers;
+ try {
+ partitionToServers = RetryUtils.retry(() -> {
+ ShuffleAssignmentsInfo response = shuffleWriteClient.getShuffleAssignments(
+ appId, shuffleId, dependency.partitioner().numPartitions(),
+ partitionNumPerRange, assignmentTags, requiredShuffleServerNumber);
+ registerShuffleServers(appId, shuffleId, response.getServerToPartitionRanges());
+ return response.getPartitionToServers();
+ }, retryInterval, retryTimes);
+ } catch (Throwable throwable) {
+ throw new RssException("registerShuffle failed!", throwable);
+ }
startHeartbeat();
- registerShuffleServers(appId, shuffleId, response.getServerToPartitionRanges());
LOG.info("RegisterShuffle with ShuffleId[" + shuffleId + "], partitionNum[" + partitionToServers.size() + "]");
return new RssShuffleHandle(shuffleId, appId, numMaps, dependency, partitionToServers, remoteStorage);
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 80fac99..030e56f 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -64,6 +64,8 @@ import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleAssignmentsInfo;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.RetryUtils;
import org.apache.uniffle.common.util.RssUtils;
import org.apache.uniffle.common.util.ThreadUtils;
@@ -258,17 +260,26 @@ public class RssShuffleManager implements ShuffleManager {
int requiredShuffleServerNumber = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER);
- ShuffleAssignmentsInfo response = shuffleWriteClient.getShuffleAssignments(
- id.get(),
- shuffleId,
- dependency.partitioner().numPartitions(),
- 1,
- assignmentTags,
- requiredShuffleServerNumber);
- Map<Integer, List<ShuffleServerInfo>> partitionToServers = response.getPartitionToServers();
-
+ // retryInterval must bigger than `rss.server.heartbeat.timeout`, or maybe it will return the same result
+ long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
+ int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
+ Map<Integer, List<ShuffleServerInfo>> partitionToServers;
+ try {
+ partitionToServers = RetryUtils.retry(() -> {
+ ShuffleAssignmentsInfo response = shuffleWriteClient.getShuffleAssignments(
+ id.get(),
+ shuffleId,
+ dependency.partitioner().numPartitions(),
+ 1,
+ assignmentTags,
+ requiredShuffleServerNumber);
+ registerShuffleServers(id.get(), shuffleId, response.getServerToPartitionRanges());
+ return response.getPartitionToServers();
+ }, retryInterval, retryTimes);
+ } catch (Throwable throwable) {
+ throw new RssException("registerShuffle failed!", throwable);
+ }
startHeartbeat();
- registerShuffleServers(id.get(), shuffleId, response.getServerToPartitionRanges());
LOG.info("RegisterShuffle with ShuffleId[" + shuffleId + "], partitionNum[" + partitionToServers.size()
+ "], shuffleServerForResult: " + partitionToServers);
diff --git a/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java b/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java
index eb6006a..b3247b4 100644
--- a/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java
+++ b/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java
@@ -60,6 +60,11 @@ public class RssClientConfig {
public static final String RSS_CLIENT_READ_BUFFER_SIZE_DEFAULT_VALUE = "14m";
// The tags specified by rss client to determine server assignment.
public static final String RSS_CLIENT_ASSIGNMENT_TAGS = "rss.client.assignment.tags";
+
+ public static final String RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL = "rss.client.assignment.retry.interval";
+ public static final long RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL_DEFAULT_VALUE = 65000;
+ public static final String RSS_CLIENT_ASSIGNMENT_RETRY_TIMES = "rss.client.assignment.retry.times";
+ public static final int RSS_CLIENT_ASSIGNMENT_RETRY_TIMES_DEFAULT_VALUE = 3;
public static final String RSS_ACCESS_TIMEOUT_MS = "rss.access.timeout.ms";
public static final int RSS_ACCESS_TIMEOUT_MS_DEFAULT_VALUE = 10000;
diff --git a/common/src/main/java/org/apache/uniffle/common/util/RetryUtils.java b/common/src/main/java/org/apache/uniffle/common/util/RetryUtils.java
new file mode 100644
index 0000000..603873f
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/util/RetryUtils.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.util;
+
+import java.util.Set;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class RetryUtils {
+  private static final Logger LOG = LoggerFactory.getLogger(RetryUtils.class);
+
+  public static <T> T retry(RetryCmd<T> cmd, long intervalMs, int retryTimes) throws Throwable {
+    return retry(cmd, null, intervalMs, retryTimes, null);
+  }
+
+  /**
+   * Runs {@code cmd} until it succeeds, it fails with a non-retryable Throwable, or attempts run out.
+   * @param cmd command to execute
+   * @param callBack run after the sleep that follows each retryable failure; null for none
+   * @param intervalMs sleep between attempts, in milliseconds
+   * @param retryTimes maximum number of attempts in total, including the first one
+   * @param exceptionClasses exception classes which are retried, null for all
+   * @param <T> return type
+   * @return the result of the first successful attempt
+   * @throws Throwable the last/non-retryable failure, a callBack failure, or an interrupted sleep
+   */
+  public static <T> T retry(RetryCmd<T> cmd, RetryCallBack callBack, long intervalMs,
+      int retryTimes, Set<Class> exceptionClasses) throws Throwable {
+    int attempt = 0;
+    while (true) {
+      try {
+        return cmd.execute();
+      } catch (Throwable t) {
+        attempt++;
+        if ((exceptionClasses != null && !isInstanceOf(exceptionClasses, t)) || attempt >= retryTimes) {
+          throw t;
+        }
+        LOG.info("Retry due to Throwable, {} {}", t.getClass().getName(), t.getMessage());
+        LOG.info("Waiting {} milliseconds before next connection attempt.", intervalMs);
+        try {
+          Thread.sleep(intervalMs);
+        } catch (InterruptedException e) {
+          // Callers often wrap this in another exception; restore the flag so it stays observable.
+          Thread.currentThread().interrupt();
+          throw e;
+        }
+        if (callBack != null) {
+          callBack.execute();
+        }
+      }
+    }
+  }
+
+  private static boolean isInstanceOf(Set<Class> classes, Throwable t) {
+    return classes.stream().anyMatch(c -> c.isInstance(t));
+  }
+
+  public interface RetryCmd<T> {
+    T execute() throws Throwable;
+  }
+
+  public interface RetryCallBack {
+    void execute() throws Throwable;
+  }
+}
diff --git a/common/src/test/java/org/apache/uniffle/common/util/RetryUtilsTest.java b/common/src/test/java/org/apache/uniffle/common/util/RetryUtilsTest.java
new file mode 100644
index 0000000..d119d27
--- /dev/null
+++ b/common/src/test/java/org/apache/uniffle/common/util/RetryUtilsTest.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.util;
+
+import java.util.concurrent.atomic.AtomicInteger;
+
+import com.google.common.collect.Sets;
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.common.exception.RssException;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+/**
+ * Unit tests for {@link RetryUtils}.
+ */
+public class RetryUtilsTest {
+  @Test
+  public void testRetry() throws Throwable {
+    AtomicInteger tryTimes = new AtomicInteger();
+    AtomicInteger callbackTime = new AtomicInteger();
+    int maxTryTime = 3;
+    // Retryable failures: the command runs maxTryTime times and the callback runs between attempts.
+    assertThrows(RssException.class, () -> RetryUtils.retry(() -> {
+      tryTimes.incrementAndGet();
+      throw new RssException("");
+    }, () -> {
+      callbackTime.incrementAndGet();
+    }, 10, maxTryTime, Sets.newHashSet(RssException.class)));
+    assertEquals(maxTryTime, tryTimes.get());
+    assertEquals(maxTryTime - 1, callbackTime.get());
+
+    // Without an exception filter every Throwable is retried.
+    tryTimes.set(0);
+    assertThrows(Exception.class, () -> RetryUtils.retry(() -> {
+      tryTimes.incrementAndGet();
+      throw new Exception("");
+    }, 10, maxTryTime));
+    assertEquals(maxTryTime, tryTimes.get());
+
+    // A successful first attempt returns at once. The assertions stay outside any
+    // catch (Throwable): such a catch would also swallow a failing assertion.
+    tryTimes.set(0);
+    int ret = RetryUtils.retry(() -> {
+      tryTimes.incrementAndGet();
+      return 1;
+    }, 10, maxTryTime);
+    assertEquals(1, ret);
+    assertEquals(1, tryTimes.get());
+  }
+}
diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
index f2e35c1..ec7d227 100644
--- a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
+++ b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
@@ -20,12 +20,15 @@ package org.apache.uniffle.test;
import java.io.File;
import java.util.List;
import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.io.Files;
import org.apache.uniffle.client.util.DefaultIdHelper;
+import org.apache.uniffle.common.util.Constants;
+import org.apache.uniffle.common.util.RetryUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.BeforeAll;
@@ -39,6 +42,7 @@ import org.apache.uniffle.client.util.ClientType;
import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.common.PartitionRange;
import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.ShuffleAssignmentsInfo;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.coordinator.CoordinatorConf;
@@ -47,6 +51,7 @@ import org.apache.uniffle.storage.util.StorageType;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
@@ -273,4 +278,45 @@ public class ShuffleWithRssClientTest extends ShuffleReadWriteBase {
.sendCommit(Sets.newHashSet(shuffleServerInfo2), testAppId, 0, 2);
assertFalse(commitResult);
}
+
+  @Test
+  public void testRetryAssgin() throws Throwable {
+    // Every attempt but the last stops its assigned server, so registerShuffle fails and RetryUtils
+    // re-requests an assignment. NOTE(review): stopped servers are not restarted for later tests.
+    int maxTryTime = shuffleServers.size();
+    AtomicInteger tryTime = new AtomicInteger();
+    String appId = "app-1";
+    RemoteStorageInfo remoteStorage = new RemoteStorageInfo("");
+    ShuffleServerConf shuffleServerConf = getShuffleServerConf();
+    int heartbeatTimeout = shuffleServerConf.getInteger("rss.server.heartbeat.timeout", 65000);
+    int heartbeatInterval = shuffleServerConf.getInteger("rss.server.heartbeat.interval", 1000);
+    Thread.sleep(heartbeatInterval * 2);
+    shuffleWriteClientImpl.registerCoordinators(COORDINATOR_QUORUM);
+    ShuffleAssignmentsInfo response = RetryUtils.retry(() -> {
+      int currentTryTime = tryTime.incrementAndGet();
+      ShuffleAssignmentsInfo shuffleAssignments = shuffleWriteClientImpl.getShuffleAssignments(appId,
+          1, 1, 1, Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION), 1);
+      Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges =
+          shuffleAssignments.getServerToPartitionRanges();
+      serverToPartitionRanges.forEach((serverInfo, partitionRanges) -> {
+        if (currentTryTime < maxTryTime) {
+          shuffleServers.forEach(ss -> {
+            if (ss.getId().equals(serverInfo.getId())) {
+              try {
+                ss.stopServer();
+              } catch (Exception e) {
+                // Not rethrown: if a server survives, the early success fails the tryTime assertion.
+                e.printStackTrace();
+              }
+            }
+          });
+        }
+        shuffleWriteClientImpl.registerShuffle(serverInfo, appId, 0, partitionRanges, remoteStorage);
+      });
+      return shuffleAssignments;
+    }, heartbeatTimeout, maxTryTime);
+
+    assertNotNull(response);
+    assertEquals(maxTryTime, tryTime.get(), "every attempt but the last should have failed");
+  }
}