You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@celeborn.apache.org by zh...@apache.org on 2023/01/13 08:45:31 UTC

[incubator-celeborn] branch main updated: [CELEBORN-197] in mappartition, check transportClient whether changed while sending messages (#1145)

This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 1836fe18 [CELEBORN-197] in mappartition, check transportClient whether changed while sending messages (#1145)
1836fe18 is described below

commit 1836fe187b13deb132dd6919aaa6163d66e35bed
Author: zhongqiangczq <96...@users.noreply.github.com>
AuthorDate: Fri Jan 13 16:45:26 2023 +0800

    [CELEBORN-197] in mappartition, check transportClient whether changed while sending messages (#1145)
---
 .../apache/celeborn/client/ShuffleClientImpl.java  | 49 ++++++++++++++++------
 1 file changed, 37 insertions(+), 12 deletions(-)

diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index 7fc438a4..0676d086 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -134,6 +134,8 @@ public class ShuffleClientImpl extends ShuffleClient {
   // key: shuffleId
   private final Map<Integer, ReduceFileGroups> reduceFileGroupsMap = new ConcurrentHashMap<>();
 
+  private TransportClient currentClient;
+
   public ShuffleClientImpl(CelebornConf conf, UserIdentifier userIdentifier) {
     super();
     this.conf = conf;
@@ -1523,8 +1525,7 @@ public class ShuffleClientImpl extends ShuffleClient {
         };
     // do push data
     try {
-      TransportClient client =
-          dataClientFactory.createClient(location.getHost(), location.getPushPort(), partitionId);
+      TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState);
       ChannelFuture future = client.pushData(pushData, callback);
       pushState.pushStarted(nextBatchId, future, callback);
     } catch (Exception e) {
@@ -1534,6 +1535,23 @@ public class ShuffleClientImpl extends ShuffleClient {
     return totalLength;
   }
 
+  private TransportClient createClientWaitingInFlightRequest(
+      PartitionLocation location, String mapKey, PushState pushState)
+      throws IOException, InterruptedException {
+    TransportClient client =
+        dataClientFactory.createClient(
+            location.getHost(), location.getPushPort(), location.getId());
+    if (currentClient != client) {
+      // Make sure messages have been sent by the old client before switching, so that data
+      // is still received in order.
+      if (currentClient != null) {
+        limitMaxInFlight(mapKey, pushState, 0);
+      }
+      currentClient = client;
+    }
+    return currentClient;
+  }
+
   @Override
   public void pushDataHandShake(
       String applicationId,
@@ -1544,10 +1562,13 @@ public class ShuffleClientImpl extends ShuffleClient {
       int bufferSize,
       PartitionLocation location)
       throws IOException {
+    final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
+    final PushState pushState = pushStates.computeIfAbsent(mapKey, (s) -> new PushState(conf));
     sendMessageInternal(
         shuffleId,
         mapId,
         attemptId,
+        pushState,
         () -> {
           String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
           logger.info(
@@ -1556,8 +1577,7 @@ public class ShuffleClientImpl extends ShuffleClient {
               attemptId,
               location.getUniqueId());
           logger.debug("pushDataHandShake location:{}", location.toString());
-          TransportClient client =
-              dataClientFactory.createClient(location.getHost(), location.getPushPort());
+          TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState);
           PushDataHandShake handShake =
               new PushDataHandShake(
                   MASTER_MODE,
@@ -1581,10 +1601,13 @@ public class ShuffleClientImpl extends ShuffleClient {
       int currentRegionIdx,
       boolean isBroadcast)
       throws IOException {
+    final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
+    final PushState pushState = pushStates.computeIfAbsent(mapKey, (s) -> new PushState(conf));
     return sendMessageInternal(
         shuffleId,
         mapId,
         attemptId,
+        pushState,
         () -> {
           String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
           logger.info(
@@ -1593,8 +1616,7 @@ public class ShuffleClientImpl extends ShuffleClient {
               attemptId,
               location.getUniqueId());
           logger.debug("regionStart location:{}", location.toString());
-          TransportClient client =
-              dataClientFactory.createClient(location.getHost(), location.getPushPort());
+          TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState);
           RegionStart regionStart =
               new RegionStart(
                   MASTER_MODE,
@@ -1626,7 +1648,6 @@ public class ShuffleClientImpl extends ShuffleClient {
             if (StatusCode.SUCCESS.equals(respStatus)) {
               return Optional.of(PbSerDeUtils.fromPbPartitionLocation(response.getLocation()));
             } else if (StatusCode.MAP_ENDED.equals(respStatus)) {
-              final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
               mapperEndMap
                   .computeIfAbsent(shuffleId, (id) -> ConcurrentHashMap.newKeySet())
                   .add(mapKey);
@@ -1649,10 +1670,13 @@ public class ShuffleClientImpl extends ShuffleClient {
   public void regionFinish(
       String applicationId, int shuffleId, int mapId, int attemptId, PartitionLocation location)
       throws IOException {
+    final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
+    final PushState pushState = pushStates.computeIfAbsent(mapKey, (s) -> new PushState(conf));
     sendMessageInternal(
         shuffleId,
         mapId,
         attemptId,
+        pushState,
         () -> {
           final String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
           logger.info(
@@ -1661,8 +1685,7 @@ public class ShuffleClientImpl extends ShuffleClient {
               attemptId,
               location.getUniqueId());
           logger.debug("regionFinish location:{}", location.toString());
-          TransportClient client =
-              dataClientFactory.createClient(location.getHost(), location.getPushPort());
+          TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState);
           RegionFinish regionFinish =
               new RegionFinish(MASTER_MODE, shuffleKey, location.getUniqueId(), attemptId);
           client.sendRpcSync(regionFinish.toByteBuffer(), conf.pushDataTimeoutMs());
@@ -1671,9 +1694,12 @@ public class ShuffleClientImpl extends ShuffleClient {
   }
 
   private <R> R sendMessageInternal(
-      int shuffleId, int mapId, int attemptId, ThrowingExceptionSupplier<R, Exception> supplier)
+      int shuffleId,
+      int mapId,
+      int attemptId,
+      PushState pushState,
+      ThrowingExceptionSupplier<R, Exception> supplier)
       throws IOException {
-    PushState pushState = null;
     int batchId = 0;
     try {
       // mapKey
@@ -1687,7 +1713,6 @@ public class ShuffleClientImpl extends ShuffleClient {
             attemptId);
         return null;
       }
-      pushState = pushStates.computeIfAbsent(mapKey, (s) -> new PushState(conf));
       // force data has been send
       limitMaxInFlight(mapKey, pushState, 0);