You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@celeborn.apache.org by zh...@apache.org on 2022/11/16 13:46:24 UTC

[incubator-celeborn] branch main updated: [CELEBORN-8] [ISSUE-952][FEATURE] support register shuffle task in map partition mode (#973)

This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new fb6d1de1 [CELEBORN-8] [ISSUE-952][FEATURE] support register shuffle task in map partition mode (#973)
fb6d1de1 is described below

commit fb6d1de108ffdc871f38431f049ab76f96787c27
Author: Shuang <lv...@alibaba-inc.com>
AuthorDate: Wed Nov 16 21:46:19 2022 +0800

    [CELEBORN-8] [ISSUE-952][FEATURE] support register shuffle task in map partition mode (#973)
---
 .../apache/celeborn/client/ShuffleClientImpl.java  | 56 +++++++++++++-
 .../apache/celeborn/client/LifecycleManager.scala  | 63 ++++++++++++++--
 .../common/protocol/PartitionLocation.java         | 15 +++-
 .../celeborn/common/util/PackedPartitionId.java    | 58 +++++++++++++++
 common/src/main/proto/TransportMessages.proto      |  9 +++
 .../common/protocol/message/ControlMessages.scala  | 18 +++++
 .../common/protocol/PartitionLocationSuiteJ.java   | 14 +++-
 .../common/util/PackedPartitionIdSuiteJ.java       | 53 +++++++++++++
 .../celeborn/tests/client/ShuffleClientSuite.scala | 87 ++++++++++++++++++++++
 9 files changed, 359 insertions(+), 14 deletions(-)

diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index 9cecb140..1283059d 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -20,6 +20,7 @@ package org.apache.celeborn.client;
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.*;
+import java.util.concurrent.Callable;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.TimeUnit;
@@ -27,6 +28,7 @@ import java.util.concurrent.TimeUnit;
 import scala.reflect.ClassTag;
 import scala.reflect.ClassTag$;
 
+import com.google.common.annotations.VisibleForTesting;
 import io.netty.buffer.CompositeByteBuf;
 import io.netty.buffer.Unpooled;
 import io.netty.channel.ChannelFuture;
@@ -56,6 +58,7 @@ import org.apache.celeborn.common.rpc.RpcAddress;
 import org.apache.celeborn.common.rpc.RpcEndpointRef;
 import org.apache.celeborn.common.rpc.RpcEnv;
 import org.apache.celeborn.common.unsafe.Platform;
+import org.apache.celeborn.common.util.PackedPartitionId;
 import org.apache.celeborn.common.util.PbSerDeUtils;
 import org.apache.celeborn.common.util.ThreadUtils;
 import org.apache.celeborn.common.util.Utils;
@@ -257,13 +260,58 @@ public class ShuffleClientImpl extends ShuffleClient {
 
   private ConcurrentHashMap<Integer, PartitionLocation> registerShuffle(
       String appId, int shuffleId, int numMappers, int numPartitions) {
+    return registerShuffleInternal(
+        shuffleId,
+        numMappers,
+        numPartitions,
+        () ->
+            driverRssMetaService.askSync(
+                RegisterShuffle$.MODULE$.apply(appId, shuffleId, numMappers, numPartitions),
+                ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)));
+  }
+
+  @VisibleForTesting
+  public PartitionLocation registerMapPartitionTask(
+      String appId, int shuffleId, int numMappers, int mapId, int attemptId) {
+    int partitionId = PackedPartitionId.packedPartitionId(mapId, attemptId);
+    logger.info(
+        "register mapPartitionTask, mapId: {}, attemptId: {}, partitionId: {}",
+        mapId,
+        attemptId,
+        partitionId);
+    if (attemptId == 0) {
+      return registerMapPartitionTaskWithFirstAttempt(
+          appId, shuffleId, numMappers, mapId, attemptId, partitionId);
+    }
+
+    // TODO: registering a retried attempt (attemptId > 0) is not supported yet.
+    throw new UnsupportedOperationException("can not register shuffle task with attempt beyond 0");
+  }
+
+  private PartitionLocation registerMapPartitionTaskWithFirstAttempt(
+      String appId, int shuffleId, int numMappers, int mapId, int attemptId, int partitionId) {
+    ConcurrentHashMap<Integer, PartitionLocation> partitionLocationMap =
+        registerShuffleInternal(
+            shuffleId,
+            numMappers,
+            numMappers,
+            () ->
+                driverRssMetaService.askSync(
+                    RegisterMapPartitionTask$.MODULE$.apply(
+                        appId, shuffleId, numMappers, mapId, attemptId, partitionId),
+                    ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)));
+    return partitionLocationMap.get(partitionId);
+  }
+
+  private ConcurrentHashMap<Integer, PartitionLocation> registerShuffleInternal(
+      int shuffleId,
+      int numMappers,
+      int numPartitions,
+      Callable<PbRegisterShuffleResponse> callable) {
     int numRetries = registerShuffleMaxRetries;
     while (numRetries > 0) {
       try {
-        PbRegisterShuffleResponse response =
-            driverRssMetaService.askSync(
-                RegisterShuffle$.MODULE$.apply(appId, shuffleId, numMappers, numPartitions),
-                ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class));
+        PbRegisterShuffleResponse response = callable.call();
         StatusCode respStatus = Utils.toStatusCode(response.getStatus());
         if (StatusCode.SUCCESS.equals(respStatus)) {
           ConcurrentHashMap<Integer, PartitionLocation> result = new ConcurrentHashMap<>();
diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 0dce8391..7a854f33 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -27,6 +27,7 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.util.Random
 
+import com.google.common.annotations.VisibleForTesting
 import com.google.common.cache.{Cache, CacheBuilder}
 import org.roaringbitmap.RoaringBitmap
 
@@ -53,7 +54,8 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
   private val pushReplicateEnabled = conf.pushReplicateEnabled
   private val partitionSplitThreshold = conf.partitionSplitThreshold
   private val partitionSplitMode = conf.partitionSplitMode
-  private val partitionType = conf.shufflePartitionType
+  // shuffle id -> partition type
+  private val shufflePartitionType = new ConcurrentHashMap[Int, PartitionType]()
   private val rangeReadFilter = conf.shuffleRangeReadFilterEnabled
   private val unregisterShuffleTime = new ConcurrentHashMap[Int, Long]()
   private val stageEndTimeout = conf.pushStageEndTimeout
@@ -83,7 +85,8 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
     .maximumSize(rpcCacheSize)
     .build().asInstanceOf[Cache[Int, ByteBuffer]]
 
-  private def workerSnapshots(shuffleId: Int): util.Map[WorkerInfo, PartitionLocationInfo] =
+  @VisibleForTesting
+  def workerSnapshots(shuffleId: Int): util.Map[WorkerInfo, PartitionLocationInfo] =
     shuffleAllocatedWorkers.get(shuffleId)
 
   val newMapFunc: function.Function[Int, ConcurrentHashMap[Int, PartitionLocation]] =
@@ -293,6 +296,10 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
     rpcEnv.address.port
   }
 
+  def getPartitionType(shuffleId: Int): PartitionType = {
+    shufflePartitionType.getOrDefault(shuffleId, conf.shufflePartitionType)
+  }
+
   override def receive: PartialFunction[Any, Unit] = {
     case RemoveExpiredShuffle =>
       removeExpiredShuffle()
@@ -319,6 +326,22 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
         s"$applicationId, $shuffleId, $numMappers, $numPartitions.")
       handleRegisterShuffle(context, applicationId, shuffleId, numMappers, numPartitions)
 
+    case pb: PbRegisterMapPartitionTask =>
+      val applicationId = pb.getApplicationId
+      val shuffleId = pb.getShuffleId
+      val numMappers = pb.getNumMappers
+      val mapId = pb.getMapId
+      val attemptId = pb.getAttemptId
+      val partitionId = pb.getPartitionId
+      logDebug(s"Received Register map partition task request, " +
+        s"$applicationId, $shuffleId, $numMappers, $mapId, $attemptId, $partitionId.")
+      handleRegisterMapPartitionTask(
+        context,
+        applicationId,
+        shuffleId,
+        numMappers,
+        partitionId)
+
     case pb: PbRevive =>
       val applicationId = pb.getApplicationId
       val shuffleId = pb.getShuffleId
@@ -379,6 +402,33 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
       shuffleId: Int,
       numMappers: Int,
       numReducers: Int): Unit = {
+    handleOfferAndReserveSlots(context, applicationId, shuffleId, numMappers, numReducers)
+  }
+
+  private def handleRegisterMapPartitionTask(
+      context: RpcCallContext,
+      applicationId: String,
+      shuffleId: Int,
+      numMappers: Int,
+      partitionId: Int): Unit = {
+    shufflePartitionType.putIfAbsent(shuffleId, PartitionType.MAP)
+    handleOfferAndReserveSlots(
+      context,
+      applicationId,
+      shuffleId,
+      numMappers,
+      numMappers,
+      partitionId)
+  }
+
+  private def handleOfferAndReserveSlots(
+      context: RpcCallContext,
+      applicationId: String,
+      shuffleId: Int,
+      numMappers: Int,
+      numReducers: Int,
+      partitionId: Int = -1): Unit = {
+    val partitionType = getPartitionType(shuffleId)
     registeringShuffleRequest.synchronized {
       if (registeringShuffleRequest.containsKey(shuffleId)) {
         // If same request already exists in the registering request list for the same shuffle,
@@ -394,7 +444,8 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
             .values()
             .asScala
             .flatMap(_.getAllMasterLocationsWithMinEpoch(shuffleId.toString).asScala)
-            .filter(_.getEpoch == 0)
+            .filter(p =>
+              (partitionType == PartitionType.REDUCE && p.getEpoch == 0) || (partitionType == PartitionType.MAP && p.getId == partitionId))
             .toArray
           context.reply(RegisterShuffleResponse(StatusCode.SUCCESS, initialLocs))
           return
@@ -513,7 +564,9 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
 
       // Fifth, reply the allocated partition location to ShuffleClient.
       logInfo(s"Handle RegisterShuffle Success for $shuffleId.")
-      val allMasterPartitionLocations = slots.asScala.flatMap(_._2._1.asScala).toArray
+      val allMasterPartitionLocations = slots.asScala.flatMap(_._2._1.asScala).filter(p =>
+        partitionType ==
+          PartitionType.REDUCE || (partitionType == PartitionType.MAP && p.getId == partitionId)).toArray
       reply(RegisterShuffleResponse(StatusCode.SUCCESS, allMasterPartitionLocations))
     }
   }
@@ -1123,7 +1176,7 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
             slaveLocations,
             partitionSplitThreshold,
             partitionSplitMode,
-            partitionType,
+            getPartitionType(shuffleId),
             rangeReadFilter,
             userIdentifier))
         if (res.status.equals(StatusCode.SUCCESS)) {
diff --git a/common/src/main/java/org/apache/celeborn/common/protocol/PartitionLocation.java b/common/src/main/java/org/apache/celeborn/common/protocol/PartitionLocation.java
index 2967bd5c..dccf8783 100644
--- a/common/src/main/java/org/apache/celeborn/common/protocol/PartitionLocation.java
+++ b/common/src/main/java/org/apache/celeborn/common/protocol/PartitionLocation.java
@@ -22,6 +22,7 @@ import java.io.Serializable;
 import org.roaringbitmap.RoaringBitmap;
 
 import org.apache.celeborn.common.meta.WorkerInfo;
+import org.apache.celeborn.common.util.PackedPartitionId;
 
 public class PartitionLocation implements Serializable {
   public enum Mode {
@@ -277,9 +278,13 @@ public class PartitionLocation implements Serializable {
       peerAddr = peer.hostAndPorts();
     }
     return "PartitionLocation["
-        + "\n  id-epoch:"
+        + "\n  id(rawId-attemptId)-epoch:"
         + id
+        + "("
+        + getRawId()
         + "-"
+        + getAttemptId()
+        + ")-"
         + epoch
         + "\n  host-rpcPort-pushPort-fetchPort-replicatePort:"
         + host
@@ -313,4 +318,12 @@ public class PartitionLocation implements Serializable {
   public void setMapIdBitMap(RoaringBitmap mapIdBitMap) {
     this.mapIdBitMap = mapIdBitMap;
   }
+
+  public int getRawId() {
+    return PackedPartitionId.getRawPartitionId(id);
+  }
+
+  public int getAttemptId() {
+    return PackedPartitionId.getAttemptId(id);
+  }
 }
diff --git a/common/src/main/java/org/apache/celeborn/common/util/PackedPartitionId.java b/common/src/main/java/org/apache/celeborn/common/util/PackedPartitionId.java
new file mode 100644
index 00000000..00163f5a
--- /dev/null
+++ b/common/src/main/java/org/apache/celeborn/common/util/PackedPartitionId.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.util;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Encodes/decodes the id of a {@link org.apache.celeborn.common.protocol.PartitionLocation},
+ * packing the attemptId and the raw partition id into a single int: <br>
+ * (upper 8 bits = attemptId) (lower 24 bits = raw id) <br>
+ * (0000 0000) (0000 0000 0000 0000 0000 0000)<br>
+ *
+ * @see org.apache.celeborn.common.protocol.PartitionLocation#id
+ */
+public class PackedPartitionId {
+
+  /** The maximum partition identifier that can be encoded. Note that partition ids start from 0. */
+  static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1; // 16777215
+
+  /** The maximum partition attempt id that can be encoded. Note that attempt ids start from 0. */
+  static final int MAXIMUM_ATTEMPT_ID = (1 << 8) - 1; // 255
+
+  static final int MASK_INT_LOWER_24_BITS = (1 << 24) - 1; // selects the raw-id bits
+
+  public static int packedPartitionId(int partitionRawId, int attemptId) {
+    Preconditions.checkArgument(
+        partitionRawId >= 0 && partitionRawId <= MAXIMUM_PARTITION_ID,
+        "packedPartitionId called with invalid partitionRawId: " + partitionRawId);
+    Preconditions.checkArgument(
+        attemptId >= 0 && attemptId <= MAXIMUM_ATTEMPT_ID,
+        "packedPartitionId called with invalid attemptId: " + attemptId);
+
+    return (attemptId << 24) | partitionRawId;
+  }
+
+  public static int getRawPartitionId(int packedPartitionId) {
+    return packedPartitionId & MASK_INT_LOWER_24_BITS;
+  }
+
+  public static int getAttemptId(int packedPartitionId) {
+    return packedPartitionId >>> 24; // unsigned shift: attempt ids >= 128 set the sign bit
+  }
+}
diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto
index c6ba510a..96ef54d9 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -155,6 +155,15 @@ message PbRegisterShuffle {
   int32 numPartitions = 4;
 }
 
+message PbRegisterMapPartitionTask {
+  string applicationId = 1;
+  int32 shuffleId = 2;
+  int32 numMappers = 3;
+  int32 mapId = 4;
+  int32 attemptId = 5;
+  int32 partitionId = 6;
+}
+
 message PbRegisterShuffleResponse {
   int32 status = 1;
   repeated PbPartitionLocation partitionLocations = 2;
diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index 52893e6b..b92a996c 100644
--- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -130,6 +130,24 @@ object ControlMessages extends Logging {
         .build()
   }
 
+  object RegisterMapPartitionTask {
+    def apply(
+        appId: String,
+        shuffleId: Int,
+        numMappers: Int,
+        mapId: Int,
+        attemptId: Int,
+        partitionId: Int): PbRegisterMapPartitionTask =
+      PbRegisterMapPartitionTask.newBuilder()
+        .setApplicationId(appId)
+        .setShuffleId(shuffleId)
+        .setNumMappers(numMappers)
+        .setMapId(mapId)
+        .setAttemptId(attemptId)
+        .setPartitionId(partitionId)
+        .build()
+  }
+
   object RegisterShuffleResponse {
     def apply(
         status: StatusCode,
diff --git a/common/src/test/java/org/apache/celeborn/common/protocol/PartitionLocationSuiteJ.java b/common/src/test/java/org/apache/celeborn/common/protocol/PartitionLocationSuiteJ.java
index 1d120ab9..010f5afc 100644
--- a/common/src/test/java/org/apache/celeborn/common/protocol/PartitionLocationSuiteJ.java
+++ b/common/src/test/java/org/apache/celeborn/common/protocol/PartitionLocationSuiteJ.java
@@ -20,6 +20,8 @@ package org.apache.celeborn.common.protocol;
 import org.junit.Test;
 import org.roaringbitmap.RoaringBitmap;
 
+import org.apache.celeborn.common.util.PackedPartitionId;
+
 public class PartitionLocationSuiteJ {
 
   private final int partitionId = 0;
@@ -183,9 +185,13 @@ public class PartitionLocationSuiteJ {
     bitmap.add(1);
     bitmap.add(2);
     bitmap.add(3);
+
+    int attemptId = 10;
+    int rawPartitionId = 1000;
+    int newPartitionId = PackedPartitionId.packedPartitionId(rawPartitionId, attemptId);
     PartitionLocation location3 =
         new PartitionLocation(
-            partitionId,
+            newPartitionId,
             epoch,
             host,
             rpcPort,
@@ -199,7 +205,7 @@ public class PartitionLocationSuiteJ {
 
     String exp1 =
         "PartitionLocation[\n"
-            + "  id-epoch:0-0\n"
+            + "  id(rawId-attemptId)-epoch:0(0-0)-0\n"
             + "  host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4\n"
             + "  mode:MASTER\n"
             + "  peer:(empty)\n"
@@ -207,7 +213,7 @@ public class PartitionLocationSuiteJ {
             + "  mapIdBitMap:{}]";
     String exp2 =
         "PartitionLocation[\n"
-            + "  id-epoch:0-0\n"
+            + "  id(rawId-attemptId)-epoch:0(0-0)-0\n"
             + "  host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4\n"
             + "  mode:MASTER\n"
             + "  peer:(host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4)\n"
@@ -215,7 +221,7 @@ public class PartitionLocationSuiteJ {
             + "  mapIdBitMap:{}]";
     String exp3 =
         "PartitionLocation[\n"
-            + "  id-epoch:0-0\n"
+            + "  id(rawId-attemptId)-epoch:167773160(1000-10)-0\n"
             + "  host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4\n"
             + "  mode:MASTER\n"
             + "  peer:(host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4)\n"
diff --git a/common/src/test/java/org/apache/celeborn/common/util/PackedPartitionIdSuiteJ.java b/common/src/test/java/org/apache/celeborn/common/util/PackedPartitionIdSuiteJ.java
new file mode 100644
index 00000000..8ed0bc30
--- /dev/null
+++ b/common/src/test/java/org/apache/celeborn/common/util/PackedPartitionIdSuiteJ.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.util;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class PackedPartitionIdSuiteJ {
+
+  @Test
+  public void testNormalPackedPartitionId() {
+    assertTest(0, 0);
+    assertTest(555, 1);
+    assertTest(888, 1);
+    assertTest(10001, 100);
+
+    // boundary values: maximum partition id and/or maximum attempt id
+    assertTest(PackedPartitionId.MAXIMUM_PARTITION_ID, 11);
+    assertTest(100, PackedPartitionId.MAXIMUM_ATTEMPT_ID);
+    assertTest(PackedPartitionId.MAXIMUM_PARTITION_ID, PackedPartitionId.MAXIMUM_ATTEMPT_ID);
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testAttemptIdGreaterThanMaximumAttemptId() {
+    PackedPartitionId.packedPartitionId(0, PackedPartitionId.MAXIMUM_ATTEMPT_ID + 1);
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testPartitionIdGreaterThanMaximumPartitionId() {
+    PackedPartitionId.packedPartitionId(PackedPartitionId.MAXIMUM_PARTITION_ID + 1, 1);
+  }
+
+  private void assertTest(int partitionRawId, int attemptId) {
+    int packedPartitionId = PackedPartitionId.packedPartitionId(partitionRawId, attemptId);
+    Assert.assertEquals(partitionRawId, PackedPartitionId.getRawPartitionId(packedPartitionId));
+    Assert.assertEquals(attemptId, PackedPartitionId.getAttemptId(packedPartitionId));
+  }
+}
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala
new file mode 100644
index 00000000..e4b00d17
--- /dev/null
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.tests.client
+
+import scala.collection.JavaConverters._
+import scala.language.implicitConversions
+
+import org.junit.Assert
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.celeborn.client.{LifecycleManager, ShuffleClient, ShuffleClientImpl}
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.identity.UserIdentifier
+import org.apache.celeborn.common.util.PackedPartitionId
+import org.apache.celeborn.service.deploy.MiniClusterFeature
+
+class ShuffleClientSuite extends AnyFunSuite with MiniClusterFeature
+  with BeforeAndAfterAll {
+  val masterPort = 19097
+  val APP = "app-1"
+  var shuffleClient: ShuffleClientImpl = _
+  var lifecycleManager: LifecycleManager = _
+
+  override def beforeAll(): Unit = {
+    val masterConf = Map(
+      "celeborn.master.host" -> "localhost",
+      "celeborn.master.port" -> masterPort.toString)
+    val workerConf = Map(
+      "celeborn.master.endpoints" -> s"localhost:$masterPort")
+    setUpMiniCluster(masterConf, workerConf)
+
+    val clientConf = new CelebornConf()
+      .set("celeborn.master.endpoints", s"localhost:$masterPort")
+      .set("celeborn.push.replicate.enabled", "true")
+      .set("celeborn.push.buffer.size", "256K")
+    lifecycleManager = new LifecycleManager(APP, clientConf)
+    shuffleClient = new ShuffleClientImpl(clientConf, UserIdentifier("mock", "mock"))
+    shuffleClient.setupMetaServiceRef(lifecycleManager.self)
+  }
+
+  test("test register map partition task with first attemptId") {
+    val shuffleId = 1
+    val numMappers = 8
+    val mapId = 1
+    val attemptId = 0
+    var location =
+      shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId, attemptId)
+    Assert.assertEquals(PackedPartitionId.packedPartitionId(mapId, attemptId), location.getId)
+
+    // retry register
+    location = shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId, attemptId)
+    Assert.assertEquals(PackedPartitionId.packedPartitionId(mapId, attemptId), location.getId)
+
+    // another mapId
+    location =
+      shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId + 1, attemptId)
+    Assert.assertEquals(PackedPartitionId.packedPartitionId(mapId + 1, attemptId), location.getId)
+
+    // every mapper of the shuffle should have a reserved master location
+    val partitionLocationInfos = lifecycleManager.workerSnapshots(shuffleId).values().asScala
+    val count =
+      partitionLocationInfos.map(r => r.getAllMasterLocations(shuffleId.toString).size()).sum
+    Assert.assertEquals(numMappers, count)
+  }
+
+  override def afterAll(): Unit = {
+    // TODO refactor MiniCluster later
+    println("test done")
+    sys.exit(0)
+  }
+}