You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@uniffle.apache.org by ro...@apache.org on 2023/01/31 09:07:02 UTC

[incubator-uniffle] branch master updated: [ISSUE-468] Put unavailable servers to the end of the list when sending shuffle data (#470)

This is an automated email from the ASF dual-hosted git repository.

roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new ebbe2db2 [ISSUE-468] Put unavailable servers to the end of the list when sending shuffle data (#470)
ebbe2db2 is described below

commit ebbe2db238cc6a611aea479fab988e23c636c952
Author: xianjingfeng <58...@qq.com>
AuthorDate: Tue Jan 31 17:06:56 2023 +0800

    [ISSUE-468] Put unavailable servers to the end of the list when sending shuffle data (#470)
    
    ### What changes were proposed in this pull request?
    Move unavailable shuffle servers to the end of the server list when sending shuffle data if replica > 1.
    
    ### Why are the changes needed?
    
    If we use multiple replicas and the first shuffle server becomes unavailable, sending data takes a lot of time, because the client always sends to the first shuffle server first. See #468.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Unit test (UT): added `testSendDataWithDefectiveServers` to `ShuffleWriteClientImplTest`.
---
 .../client/impl/ShuffleWriteClientImpl.java        | 63 ++++++++++++++---
 .../client/impl/ShuffleWriteClientImplTest.java    | 78 ++++++++++++++++++++++
 2 files changed, 133 insertions(+), 8 deletions(-)

diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index a784e84c..25f4a101 100644
--- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -36,6 +36,7 @@ import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
+import org.apache.commons.collections.CollectionUtils;
 import org.apache.hadoop.security.UserGroupInformation;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 import org.slf4j.Logger;
@@ -105,6 +106,7 @@ public class ShuffleWriteClientImpl implements ShuffleWriteClient {
   private final ExecutorService dataTransferPool;
   private final int unregisterThreadPoolSize;
   private final int unregisterRequestTimeSec;
+  private Set<ShuffleServerInfo> defectiveServers;
 
   public ShuffleWriteClientImpl(
       String clientType,
@@ -133,6 +135,9 @@ public class ShuffleWriteClientImpl implements ShuffleWriteClient {
     this.dataCommitPoolSize = dataCommitPoolSize;
     this.unregisterThreadPoolSize = unregisterThreadPoolSize;
     this.unregisterRequestTimeSec = unregisterRequestTimeSec;
+    if (replica > 1) {
+      defectiveServers = Sets.newConcurrentHashSet();
+    }
   }
 
   private boolean sendShuffleDataAsync(
@@ -170,12 +175,21 @@ public class ShuffleWriteClientImpl implements ShuffleWriteClient {
           if (response.getStatusCode() == ResponseStatusCode.SUCCESS) {
             // mark a replica of block that has been sent
             serverToBlockIds.get(ssi).forEach(block -> blockIdsTracker.get(block).incrementAndGet());
+            if (defectiveServers != null) {
+              defectiveServers.remove(ssi);
+            }
             LOG.info("{} successfully.", logMsg);
           } else {
+            if (defectiveServers != null) {
+              defectiveServers.add(ssi);
+            }
             LOG.warn("{}, it failed wth statusCode[{}]", logMsg, response.getStatusCode());
             return false;
           }
         } catch (Exception e) {
+          if (defectiveServers != null) {
+            defectiveServers.add(ssi);
+          }
           LOG.warn("Send: " + serverToBlockIds.get(ssi).size() + " blocks to [" + ssi.getId() + "] failed.", e);
           return false;
         }
@@ -192,12 +206,27 @@ public class ShuffleWriteClientImpl implements ShuffleWriteClient {
     return result;
   }
 
-  private void genServerToBlocks(ShuffleBlockInfo sbi, List<ShuffleServerInfo> serverList,
+  void genServerToBlocks(
+      ShuffleBlockInfo sbi,
+      List<ShuffleServerInfo> serverList,
+      int replicaNum,
+      List<ShuffleServerInfo> excludeServers,
       Map<ShuffleServerInfo, Map<Integer, Map<Integer, List<ShuffleBlockInfo>>>> serverToBlocks,
-      Map<ShuffleServerInfo, List<Long>> serverToBlockIds) {
+      Map<ShuffleServerInfo, List<Long>> serverToBlockIds,
+      boolean excludeDefectiveServers) {
+    if (replicaNum <= 0) {
+      return;
+    }
     int partitionId = sbi.getPartitionId();
     int shuffleId = sbi.getShuffleId();
+    int assignedNum = 0;
     for (ShuffleServerInfo ssi : serverList) {
+      if (excludeDefectiveServers && replica > 1 && defectiveServers.contains(ssi)) {
+        continue;
+      }
+      if (CollectionUtils.isNotEmpty(excludeServers) && excludeServers.contains(ssi)) {
+        continue;
+      }
       if (!serverToBlockIds.containsKey(ssi)) {
         serverToBlockIds.put(ssi, Lists.newArrayList());
       }
@@ -216,6 +245,18 @@ public class ShuffleWriteClientImpl implements ShuffleWriteClient {
         partitionToBlocks.put(partitionId, Lists.newArrayList());
       }
       partitionToBlocks.get(partitionId).add(sbi);
+      if (excludeServers != null) {
+        excludeServers.add(ssi);
+      }
+      assignedNum++;
+      if (assignedNum >= replicaNum) {
+        break;
+      }
+    }
+
+    if (assignedNum < replicaNum && excludeDefectiveServers) {
+      genServerToBlocks(sbi, serverList, replicaNum - assignedNum,
+          excludeServers, serverToBlocks, serverToBlockIds, false);
     }
   }
 
@@ -247,14 +288,15 @@ public class ShuffleWriteClientImpl implements ShuffleWriteClient {
     for (ShuffleBlockInfo sbi : shuffleBlockInfoList) {
       List<ShuffleServerInfo> allServers = sbi.getShuffleServerInfos();
       if (replicaSkipEnabled) {
-        genServerToBlocks(sbi, allServers.subList(0, replicaWrite),
-            primaryServerToBlocks, primaryServerToBlockIds);
-        genServerToBlocks(sbi, allServers.subList(replicaWrite, replica),
-            secondaryServerToBlocks, secondaryServerToBlockIds);
+        List<ShuffleServerInfo> excludeServers = new ArrayList<>();
+        genServerToBlocks(sbi, allServers, replicaWrite, excludeServers,
+            primaryServerToBlocks, primaryServerToBlockIds, true);
+        genServerToBlocks(sbi, allServers,replica - replicaWrite,
+            excludeServers, secondaryServerToBlocks, secondaryServerToBlockIds, false);
       } else {
         // When replicaSkip is disabled, we send data to all replicas within one round.
-        genServerToBlocks(sbi, allServers,
-            primaryServerToBlocks, primaryServerToBlockIds);
+        genServerToBlocks(sbi, allServers, allServers.size(),
+            null, primaryServerToBlocks, primaryServerToBlockIds, false);
       }
     }
 
@@ -756,6 +798,11 @@ public class ShuffleWriteClientImpl implements ShuffleWriteClient {
     return ShuffleServerClientFactory.getInstance().getShuffleServerClient(clientType, shuffleServerInfo);
   }
 
+  @VisibleForTesting
+  Set<ShuffleServerInfo> getDefectiveServers() {
+    return defectiveServers;
+  }
+
   void addShuffleServer(String appId, int shuffleId, ShuffleServerInfo serverInfo) {
     Map<Integer, Set<ShuffleServerInfo>> appServerMap = shuffleServerInfoMap.get(appId);
     if (appServerMap == null) {
diff --git a/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java b/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java
index 47cb1e1f..3693a941 100644
--- a/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java
+++ b/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java
@@ -17,10 +17,12 @@
 
 package org.apache.uniffle.client.impl;
 
+import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.TimeUnit;
 
 import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
 import org.awaitility.Awaitility;
 import org.junit.jupiter.api.Test;
 import org.mockito.Mockito;
@@ -104,4 +106,80 @@ public class ShuffleWriteClientImplTest {
     shuffleWriteClient.unregisterShuffle(appId1, 1);
     assertEquals(1, shuffleWriteClient.getAllShuffleServers(appId1).size());
   }
+
+  @Test
+  public void testSendDataWithDefectiveServers() {
+    ShuffleWriteClientImpl shuffleWriteClient =
+        new ShuffleWriteClientImpl("GRPC", 3, 2000, 4, 3, 2, 2, true, 1, 1, 10, 10);
+    ShuffleServerClient mockShuffleServerClient = mock(ShuffleServerClient.class);
+    ShuffleWriteClientImpl spyClient = Mockito.spy(shuffleWriteClient);
+    doReturn(mockShuffleServerClient).when(spyClient).getShuffleServerClient(any());
+    when(mockShuffleServerClient.sendShuffleData(any())).thenReturn(
+        new RssSendShuffleDataResponse(ResponseStatusCode.NO_BUFFER),
+        new RssSendShuffleDataResponse(ResponseStatusCode.SUCCESS),
+        new RssSendShuffleDataResponse(ResponseStatusCode.SUCCESS));
+
+    String appId = "testSendDataWithDefectiveServers_appId";
+    ShuffleServerInfo ssi1 = new ShuffleServerInfo("127.0.0.1", 0);
+    ShuffleServerInfo ssi2 = new ShuffleServerInfo("127.0.0.1", 1);
+    ShuffleServerInfo ssi3 = new ShuffleServerInfo("127.0.0.1", 2);
+    List<ShuffleServerInfo> shuffleServerInfoList =
+        Lists.newArrayList(ssi1, ssi2, ssi3);
+    List<ShuffleBlockInfo> shuffleBlockInfoList = Lists.newArrayList(new ShuffleBlockInfo(
+        0, 0, 10, 10, 10, new byte[]{1}, shuffleServerInfoList, 10, 100, 0));
+    SendShuffleDataResult result = spyClient.sendShuffleData(appId, shuffleBlockInfoList, () -> false);
+    assertEquals(0, result.getFailedBlockIds().size());
+
+    // Send data for the second time, the first shuffle server will be moved to the last.
+    when(mockShuffleServerClient.sendShuffleData(any())).thenReturn(
+        new RssSendShuffleDataResponse(ResponseStatusCode.SUCCESS),
+        new RssSendShuffleDataResponse(ResponseStatusCode.SUCCESS));
+    List<ShuffleServerInfo> excludeServers = new ArrayList<>();
+    spyClient.genServerToBlocks(shuffleBlockInfoList.get(0), shuffleServerInfoList,
+        2, excludeServers, Maps.newHashMap(), Maps.newHashMap(), true);
+    assertEquals(2, excludeServers.size());
+    assertEquals(ssi2, excludeServers.get(0));
+    assertEquals(ssi3, excludeServers.get(1));
+    spyClient.genServerToBlocks(shuffleBlockInfoList.get(0), shuffleServerInfoList,
+        1, excludeServers, Maps.newHashMap(), Maps.newHashMap(), false);
+    assertEquals(3, excludeServers.size());
+    assertEquals(ssi1, excludeServers.get(2));
+    result = spyClient.sendShuffleData(appId, shuffleBlockInfoList, () -> false);
+    assertEquals(0, result.getFailedBlockIds().size());
+
+    // Send data for the third time, the first server will be removed from the defectiveServers
+    // and the second server will be added to the defectiveServers.
+    when(mockShuffleServerClient.sendShuffleData(any())).thenReturn(
+        new RssSendShuffleDataResponse(ResponseStatusCode.NO_BUFFER),
+        new RssSendShuffleDataResponse(ResponseStatusCode.SUCCESS),
+        new RssSendShuffleDataResponse(ResponseStatusCode.SUCCESS));
+    List<ShuffleServerInfo> shuffleServerInfoList2 = Lists.newArrayList(ssi2, ssi1, ssi3);
+    List<ShuffleBlockInfo> shuffleBlockInfoList2 = Lists.newArrayList(new ShuffleBlockInfo(0, 0, 10, 10, 10,
+        new byte[]{1}, shuffleServerInfoList2, 10, 100, 0));
+    result = spyClient.sendShuffleData(appId, shuffleBlockInfoList2, () -> false);
+    assertEquals(0, result.getFailedBlockIds().size());
+    assertEquals(1, spyClient.getDefectiveServers().size());
+    assertEquals(ssi2, spyClient.getDefectiveServers().toArray()[0]);
+    excludeServers = new ArrayList<>();
+    spyClient.genServerToBlocks(shuffleBlockInfoList.get(0), shuffleServerInfoList,
+        2, excludeServers, Maps.newHashMap(), Maps.newHashMap(), true);
+    assertEquals(2, excludeServers.size());
+    assertEquals(ssi1, excludeServers.get(0));
+    assertEquals(ssi3, excludeServers.get(1));
+    spyClient.genServerToBlocks(shuffleBlockInfoList.get(0), shuffleServerInfoList,
+        1, excludeServers, Maps.newHashMap(), Maps.newHashMap(), false);
+    assertEquals(3, excludeServers.size());
+    assertEquals(ssi2, excludeServers.get(2));
+
+    // Check whether it is normal when two shuffle servers in defectiveServers
+    spyClient.getDefectiveServers().add(ssi1);
+    assertEquals(2, spyClient.getDefectiveServers().size());
+    excludeServers = new ArrayList<>();
+    spyClient.genServerToBlocks(shuffleBlockInfoList.get(0), shuffleServerInfoList,
+        2, excludeServers, Maps.newHashMap(), Maps.newHashMap(), true);
+    assertEquals(2, excludeServers.size());
+    assertEquals(ssi3, excludeServers.get(0));
+    assertEquals(ssi1, excludeServers.get(1));
+  }
+
 }