You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@celeborn.apache.org by zh...@apache.org on 2022/11/16 13:46:24 UTC
[incubator-celeborn] branch main updated: [CELEBORN-8] [ISSUE-952][FEATURE] support register shuffle task in map partition mode (#973)
This is an automated email from the ASF dual-hosted git repository.
zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new fb6d1de1 [CELEBORN-8] [ISSUE-952][FEATURE] support register shuffle task in map partition mode (#973)
fb6d1de1 is described below
commit fb6d1de108ffdc871f38431f049ab76f96787c27
Author: Shuang <lv...@alibaba-inc.com>
AuthorDate: Wed Nov 16 21:46:19 2022 +0800
[CELEBORN-8] [ISSUE-952][FEATURE] support register shuffle task in map partition mode (#973)
---
.../apache/celeborn/client/ShuffleClientImpl.java | 56 +++++++++++++-
.../apache/celeborn/client/LifecycleManager.scala | 63 ++++++++++++++--
.../common/protocol/PartitionLocation.java | 15 +++-
.../celeborn/common/util/PackedPartitionId.java | 58 +++++++++++++++
common/src/main/proto/TransportMessages.proto | 9 +++
.../common/protocol/message/ControlMessages.scala | 18 +++++
.../common/protocol/PartitionLocationSuiteJ.java | 14 +++-
.../common/util/PackedPartitionIdSuiteJ.java | 53 +++++++++++++
.../celeborn/tests/client/ShuffleClientSuite.scala | 87 ++++++++++++++++++++++
9 files changed, 359 insertions(+), 14 deletions(-)
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index 9cecb140..1283059d 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -20,6 +20,7 @@ package org.apache.celeborn.client;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.*;
+import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
@@ -27,6 +28,7 @@ import java.util.concurrent.TimeUnit;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
+import com.google.common.annotations.VisibleForTesting;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
@@ -56,6 +58,7 @@ import org.apache.celeborn.common.rpc.RpcAddress;
import org.apache.celeborn.common.rpc.RpcEndpointRef;
import org.apache.celeborn.common.rpc.RpcEnv;
import org.apache.celeborn.common.unsafe.Platform;
+import org.apache.celeborn.common.util.PackedPartitionId;
import org.apache.celeborn.common.util.PbSerDeUtils;
import org.apache.celeborn.common.util.ThreadUtils;
import org.apache.celeborn.common.util.Utils;
@@ -257,13 +260,58 @@ public class ShuffleClientImpl extends ShuffleClient {
private ConcurrentHashMap<Integer, PartitionLocation> registerShuffle(
String appId, int shuffleId, int numMappers, int numPartitions) {
+ return registerShuffleInternal(
+ shuffleId,
+ numMappers,
+ numMappers,
+ () ->
+ driverRssMetaService.askSync(
+ RegisterShuffle$.MODULE$.apply(appId, shuffleId, numMappers, numPartitions),
+ ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)));
+ }
+
+ /**
+ * Registers one map task of a map-partition shuffle and returns the partition location obtained
+ * for it.
+ *
+ * <p>The partition id is packed from {@code mapId} and {@code attemptId} with {@link
+ * PackedPartitionId#packedPartitionId(int, int)}. Only the first attempt ({@code attemptId ==
+ * 0}) is handled at the moment.
+ *
+ * @param appId application id placed in the request
+ * @param shuffleId id of the shuffle the task belongs to
+ * @param numMappers number of map tasks of the shuffle
+ * @param mapId id of the map task being registered
+ * @param attemptId attempt number of the map task
+ * @return the location registered under the packed partition id
+ * @throws UnsupportedOperationException if {@code attemptId} is not 0
+ */
+ @VisibleForTesting
+ public PartitionLocation registerMapPartitionTask(
+ String appId, int shuffleId, int numMappers, int mapId, int attemptId) {
+ final int packedId = PackedPartitionId.packedPartitionId(mapId, attemptId);
+ logger.info(
+ "register mapPartitionTask, mapId: {}, attemptId: {}, partitionId: {}",
+ mapId,
+ attemptId,
+ packedId);
+
+ if (attemptId != 0) {
+ // TODO: support registering attempts beyond the first one.
+ throw new UnsupportedOperationException("can not register shuffle task with attempt beyond 0");
+ }
+
+ return registerMapPartitionTaskWithFirstAttempt(
+ appId, shuffleId, numMappers, mapId, attemptId, packedId);
+ }
+
+ /**
+ * Sends a {@code RegisterMapPartitionTask} request for the first attempt of a map task and
+ * picks the location stored under {@code partitionId} from the registered locations.
+ *
+ * <p>{@code numMappers} is passed to {@code registerShuffleInternal} as the partition count as
+ * well.
+ */
+ private PartitionLocation registerMapPartitionTaskWithFirstAttempt(
+ String appId, int shuffleId, int numMappers, int mapId, int attemptId, int partitionId) {
+ Callable<PbRegisterShuffleResponse> registerCall =
+ () ->
+ driverRssMetaService.askSync(
+ RegisterMapPartitionTask$.MODULE$.apply(
+ appId, shuffleId, numMappers, mapId, attemptId, partitionId),
+ ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class));
+ ConcurrentHashMap<Integer, PartitionLocation> locations =
+ registerShuffleInternal(shuffleId, numMappers, numMappers, registerCall);
+ return locations.get(partitionId);
+ }
+
+ private ConcurrentHashMap<Integer, PartitionLocation> registerShuffleInternal(
+ int shuffleId,
+ int numMappers,
+ int numPartitions,
+ Callable<PbRegisterShuffleResponse> callable) {
int numRetries = registerShuffleMaxRetries;
while (numRetries > 0) {
try {
- PbRegisterShuffleResponse response =
- driverRssMetaService.askSync(
- RegisterShuffle$.MODULE$.apply(appId, shuffleId, numMappers, numPartitions),
- ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class));
+ PbRegisterShuffleResponse response = callable.call();
StatusCode respStatus = Utils.toStatusCode(response.getStatus());
if (StatusCode.SUCCESS.equals(respStatus)) {
ConcurrentHashMap<Integer, PartitionLocation> result = new ConcurrentHashMap<>();
diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 0dce8391..7a854f33 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -27,6 +27,7 @@ import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.Random
+import com.google.common.annotations.VisibleForTesting
import com.google.common.cache.{Cache, CacheBuilder}
import org.roaringbitmap.RoaringBitmap
@@ -53,7 +54,8 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
private val pushReplicateEnabled = conf.pushReplicateEnabled
private val partitionSplitThreshold = conf.partitionSplitThreshold
private val partitionSplitMode = conf.partitionSplitMode
- private val partitionType = conf.shufflePartitionType
+ // shuffle id -> partition type
+ private val shufflePartitionType = new ConcurrentHashMap[Int, PartitionType]()
private val rangeReadFilter = conf.shuffleRangeReadFilterEnabled
private val unregisterShuffleTime = new ConcurrentHashMap[Int, Long]()
private val stageEndTimeout = conf.pushStageEndTimeout
@@ -83,7 +85,8 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
.maximumSize(rpcCacheSize)
.build().asInstanceOf[Cache[Int, ByteBuffer]]
- private def workerSnapshots(shuffleId: Int): util.Map[WorkerInfo, PartitionLocationInfo] =
+ @VisibleForTesting
+ def workerSnapshots(shuffleId: Int): util.Map[WorkerInfo, PartitionLocationInfo] =
shuffleAllocatedWorkers.get(shuffleId)
val newMapFunc: function.Function[Int, ConcurrentHashMap[Int, PartitionLocation]] =
@@ -293,6 +296,10 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
rpcEnv.address.port
}
+ /**
+ * Partition type of the given shuffle: the type recorded in shufflePartitionType when one
+ * exists, otherwise the configured default conf.shufflePartitionType.
+ */
+ def getPartitionType(shuffleId: Int): PartitionType =
+ shufflePartitionType.getOrDefault(shuffleId, conf.shufflePartitionType)
+
override def receive: PartialFunction[Any, Unit] = {
case RemoveExpiredShuffle =>
removeExpiredShuffle()
@@ -319,6 +326,22 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
s"$applicationId, $shuffleId, $numMappers, $numPartitions.")
handleRegisterShuffle(context, applicationId, shuffleId, numMappers, numPartitions)
+ case pb: PbRegisterMapPartitionTask =>
+ val applicationId = pb.getApplicationId
+ val shuffleId = pb.getShuffleId
+ val numMappers = pb.getNumMappers
+ val mapId = pb.getMapId
+ val attemptId = pb.getAttemptId
+ val partitionId = pb.getPartitionId
+ logDebug(s"Received Register map partition task request, " +
+ s"$applicationId, $shuffleId, $numMappers, $mapId, $attemptId, $partitionId.")
+ handleRegisterMapPartitionTask(
+ context,
+ applicationId,
+ shuffleId,
+ numMappers,
+ partitionId)
+
case pb: PbRevive =>
val applicationId = pb.getApplicationId
val shuffleId = pb.getShuffleId
@@ -379,6 +402,33 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
shuffleId: Int,
numMappers: Int,
numReducers: Int): Unit = {
+ handleOfferAndReserveSlots(context, applicationId, shuffleId, numMappers, numReducers)
+ }
+
+ /**
+ * Handles a map partition task registration. Records the shuffle as PartitionType.MAP unless a
+ * type is already recorded for it, then offers and reserves slots using numMappers as the
+ * reducer count and forwarding partitionId to the shared slot handling.
+ */
+ private def handleRegisterMapPartitionTask(
+ context: RpcCallContext,
+ applicationId: String,
+ shuffleId: Int,
+ numMappers: Int,
+ partitionId: Int): Unit = {
+ shufflePartitionType.putIfAbsent(shuffleId, PartitionType.MAP)
+ handleOfferAndReserveSlots(
+ context,
+ applicationId,
+ shuffleId,
+ numMappers,
+ numReducers = numMappers,
+ partitionId = partitionId)
+ }
+
+ private def handleOfferAndReserveSlots(
+ context: RpcCallContext,
+ applicationId: String,
+ shuffleId: Int,
+ numMappers: Int,
+ numReducers: Int,
+ partitionId: Int = -1): Unit = {
+ val partitionType = getPartitionType(shuffleId)
registeringShuffleRequest.synchronized {
if (registeringShuffleRequest.containsKey(shuffleId)) {
// If same request already exists in the registering request list for the same shuffle,
@@ -394,7 +444,8 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
.values()
.asScala
.flatMap(_.getAllMasterLocationsWithMinEpoch(shuffleId.toString).asScala)
- .filter(_.getEpoch == 0)
+ .filter(p =>
+ (partitionType == PartitionType.REDUCE && p.getEpoch == 0) || (partitionType == PartitionType.MAP && p.getId == partitionId))
.toArray
context.reply(RegisterShuffleResponse(StatusCode.SUCCESS, initialLocs))
return
@@ -513,7 +564,9 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
// Fifth, reply the allocated partition location to ShuffleClient.
logInfo(s"Handle RegisterShuffle Success for $shuffleId.")
- val allMasterPartitionLocations = slots.asScala.flatMap(_._2._1.asScala).toArray
+ val allMasterPartitionLocations = slots.asScala.flatMap(_._2._1.asScala).filter(p =>
+ partitionType ==
+ PartitionType.REDUCE || (partitionType == PartitionType.MAP && p.getId == partitionId)).toArray
reply(RegisterShuffleResponse(StatusCode.SUCCESS, allMasterPartitionLocations))
}
}
@@ -1123,7 +1176,7 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
slaveLocations,
partitionSplitThreshold,
partitionSplitMode,
- partitionType,
+ getPartitionType(shuffleId),
rangeReadFilter,
userIdentifier))
if (res.status.equals(StatusCode.SUCCESS)) {
diff --git a/common/src/main/java/org/apache/celeborn/common/protocol/PartitionLocation.java b/common/src/main/java/org/apache/celeborn/common/protocol/PartitionLocation.java
index 2967bd5c..dccf8783 100644
--- a/common/src/main/java/org/apache/celeborn/common/protocol/PartitionLocation.java
+++ b/common/src/main/java/org/apache/celeborn/common/protocol/PartitionLocation.java
@@ -22,6 +22,7 @@ import java.io.Serializable;
import org.roaringbitmap.RoaringBitmap;
import org.apache.celeborn.common.meta.WorkerInfo;
+import org.apache.celeborn.common.util.PackedPartitionId;
public class PartitionLocation implements Serializable {
public enum Mode {
@@ -277,9 +278,13 @@ public class PartitionLocation implements Serializable {
peerAddr = peer.hostAndPorts();
}
return "PartitionLocation["
- + "\n id-epoch:"
+ + "\n id(rawId-attemptId)-epoch:"
+ id
+ + "("
+ + getRawId()
+ "-"
+ + getAttemptId()
+ + ")-"
+ epoch
+ "\n host-rpcPort-pushPort-fetchPort-replicatePort:"
+ host
@@ -313,4 +318,12 @@ public class PartitionLocation implements Serializable {
public void setMapIdBitMap(RoaringBitmap mapIdBitMap) {
this.mapIdBitMap = mapIdBitMap;
}
+
+ /** Returns the raw partition id decoded from this location's packed {@code id}. */
+ public int getRawId() {
+ return PackedPartitionId.getRawPartitionId(this.id);
+ }
+
+ /** Returns the attempt id decoded from this location's packed {@code id}. */
+ public int getAttemptId() {
+ return PackedPartitionId.getAttemptId(this.id);
+ }
}
diff --git a/common/src/main/java/org/apache/celeborn/common/util/PackedPartitionId.java b/common/src/main/java/org/apache/celeborn/common/util/PackedPartitionId.java
new file mode 100644
index 00000000..00163f5a
--- /dev/null
+++ b/common/src/main/java/org/apache/celeborn/common/util/PackedPartitionId.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.util;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Encodes and decodes the id of a partition location as a single {@code int} that combines a
+ * task attempt id with a raw partition id: <br>
+ * (upper 8 bits = attemptId) (lower 24 bits = raw id) <br>
+ * (0000 0000) (0000 0000 0000 0000 0000 0000)<br>
+ *
+ * <p>Attempt ids of 128 and above set the sign bit, so a packed id may be negative; decoding
+ * still returns the original values.
+ *
+ * @see org.apache.celeborn.common.protocol.PartitionLocation#id
+ */
+public class PackedPartitionId {
+
+ /** The maximum partition identifier that can be encoded. Note that partition ids start from 0. */
+ static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1; // 16777215
+
+ /** The maximum partition attempt id that can be encoded. Note that attempt ids start from 0. */
+ static final int MAXIMUM_ATTEMPT_ID = (1 << 8) - 1; // 255
+
+ /** Bit mask selecting the lower 24 bits of a packed id, i.e. the raw partition id. */
+ static final int MASK_INT_LOWER_24_BITS = (int) (1L << 24) - 1;
+
+ /**
+ * Packs a raw partition id and an attempt id into one {@code int}.
+ *
+ * @param partitionRawId raw partition id, from 0 to {@link #MAXIMUM_PARTITION_ID} inclusive
+ * @param attemptId attempt id, from 0 to {@link #MAXIMUM_ATTEMPT_ID} inclusive
+ * @return the packed id, with {@code attemptId} in the upper 8 bits and {@code partitionRawId}
+ * in the lower 24 bits
+ * @throws IllegalArgumentException if either argument is negative or above its maximum
+ */
+ public static int packedPartitionId(int partitionRawId, int attemptId) {
+ // Negative values must be rejected too: their high bits would spill into the other field
+ // and the packed id would no longer decode to the inputs.
+ Preconditions.checkArgument(
+ partitionRawId >= 0 && partitionRawId <= MAXIMUM_PARTITION_ID,
+ "packedPartitionId called with invalid partitionRawId: " + partitionRawId);
+ Preconditions.checkArgument(
+ attemptId >= 0 && attemptId <= MAXIMUM_ATTEMPT_ID,
+ "packedPartitionId called with invalid attemptId: " + attemptId);
+
+ return (attemptId << 24) | partitionRawId;
+ }
+
+ /**
+ * Extracts the raw partition id (lower 24 bits) from a packed id.
+ *
+ * @param packedPartitionId id produced by {@link #packedPartitionId(int, int)}
+ * @return the raw partition id
+ */
+ public static int getRawPartitionId(int packedPartitionId) {
+ return packedPartitionId & MASK_INT_LOWER_24_BITS;
+ }
+
+ /**
+ * Extracts the attempt id (upper 8 bits) from a packed id.
+ *
+ * @param packedPartitionId id produced by {@link #packedPartitionId(int, int)}
+ * @return the attempt id
+ */
+ public static int getAttemptId(int packedPartitionId) {
+ // Unsigned shift: a packed id with the sign bit set must not be sign-extended.
+ return packedPartitionId >>> 24;
+ }
+}
diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto
index c6ba510a..96ef54d9 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -155,6 +155,15 @@ message PbRegisterShuffle {
int32 numPartitions = 4;
}
+// Request sent by the shuffle client to register one map task of a map-partition shuffle.
+message PbRegisterMapPartitionTask {
+ string applicationId = 1;
+ int32 shuffleId = 2;
+ int32 numMappers = 3;
+ int32 mapId = 4;
+ int32 attemptId = 5;
+ // Packed by the client from mapId and attemptId (PackedPartitionId.packedPartitionId).
+ int32 partitionId = 6;
+}
+
message PbRegisterShuffleResponse {
int32 status = 1;
repeated PbPartitionLocation partitionLocations = 2;
diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index 52893e6b..b92a996c 100644
--- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -130,6 +130,24 @@ object ControlMessages extends Logging {
.build()
}
+ /** Factory for the PbRegisterMapPartitionTask protobuf message. */
+ object RegisterMapPartitionTask {
+ /** Builds a PbRegisterMapPartitionTask carrying the given fields unchanged. */
+ def apply(
+ appId: String,
+ shuffleId: Int,
+ numMappers: Int,
+ mapId: Int,
+ attemptId: Int,
+ partitionId: Int): PbRegisterMapPartitionTask = {
+ val builder = PbRegisterMapPartitionTask.newBuilder()
+ builder.setApplicationId(appId)
+ builder.setShuffleId(shuffleId)
+ builder.setNumMappers(numMappers)
+ builder.setMapId(mapId)
+ builder.setAttemptId(attemptId)
+ builder.setPartitionId(partitionId)
+ builder.build()
+ }
+ }
+
object RegisterShuffleResponse {
def apply(
status: StatusCode,
diff --git a/common/src/test/java/org/apache/celeborn/common/protocol/PartitionLocationSuiteJ.java b/common/src/test/java/org/apache/celeborn/common/protocol/PartitionLocationSuiteJ.java
index 1d120ab9..010f5afc 100644
--- a/common/src/test/java/org/apache/celeborn/common/protocol/PartitionLocationSuiteJ.java
+++ b/common/src/test/java/org/apache/celeborn/common/protocol/PartitionLocationSuiteJ.java
@@ -20,6 +20,8 @@ package org.apache.celeborn.common.protocol;
import org.junit.Test;
import org.roaringbitmap.RoaringBitmap;
+import org.apache.celeborn.common.util.PackedPartitionId;
+
public class PartitionLocationSuiteJ {
private final int partitionId = 0;
@@ -183,9 +185,13 @@ public class PartitionLocationSuiteJ {
bitmap.add(1);
bitmap.add(2);
bitmap.add(3);
+
+ int attemptId = 10;
+ int rawPartitionId = 1000;
+ int newPartitionId = PackedPartitionId.packedPartitionId(rawPartitionId, attemptId);
PartitionLocation location3 =
new PartitionLocation(
- partitionId,
+ newPartitionId,
epoch,
host,
rpcPort,
@@ -199,7 +205,7 @@ public class PartitionLocationSuiteJ {
String exp1 =
"PartitionLocation[\n"
- + " id-epoch:0-0\n"
+ + " id(rawId-attemptId)-epoch:0(0-0)-0\n"
+ " host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4\n"
+ " mode:MASTER\n"
+ " peer:(empty)\n"
@@ -207,7 +213,7 @@ public class PartitionLocationSuiteJ {
+ " mapIdBitMap:{}]";
String exp2 =
"PartitionLocation[\n"
- + " id-epoch:0-0\n"
+ + " id(rawId-attemptId)-epoch:0(0-0)-0\n"
+ " host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4\n"
+ " mode:MASTER\n"
+ " peer:(host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4)\n"
@@ -215,7 +221,7 @@ public class PartitionLocationSuiteJ {
+ " mapIdBitMap:{}]";
String exp3 =
"PartitionLocation[\n"
- + " id-epoch:0-0\n"
+ + " id(rawId-attemptId)-epoch:167773160(1000-10)-0\n"
+ " host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4\n"
+ " mode:MASTER\n"
+ " peer:(host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4)\n"
diff --git a/common/src/test/java/org/apache/celeborn/common/util/PackedPartitionIdSuiteJ.java b/common/src/test/java/org/apache/celeborn/common/util/PackedPartitionIdSuiteJ.java
new file mode 100644
index 00000000..8ed0bc30
--- /dev/null
+++ b/common/src/test/java/org/apache/celeborn/common/util/PackedPartitionIdSuiteJ.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.util;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+/** Round-trip and upper-bound tests for {@link PackedPartitionId}. */
+public class PackedPartitionIdSuiteJ {
+
+ @Test
+ public void testNormalPackedPartitionId() {
+ assertTest(0, 0);
+ assertTest(555, 1);
+ assertTest(888, 1);
+ assertTest(10001, 100);
+
+ // testUseMaxPartitionId or MaxAttemptId
+ assertTest(PackedPartitionId.MAXIMUM_PARTITION_ID, 11);
+ assertTest(100, PackedPartitionId.MAXIMUM_ATTEMPT_ID);
+ assertTest(PackedPartitionId.MAXIMUM_PARTITION_ID, PackedPartitionId.MAXIMUM_ATTEMPT_ID);
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testAttemptIdGreaterThanMaximumAttemptId() {
+ PackedPartitionId.packedPartitionId(0, PackedPartitionId.MAXIMUM_ATTEMPT_ID + 1);
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testPartitionIdGreaterThanMaximumPartitionId() {
+ PackedPartitionId.packedPartitionId(PackedPartitionId.MAXIMUM_PARTITION_ID + 1, 1);
+ }
+
+ /** Packs the two ids and checks that both decode back to the original values. */
+ private void assertTest(int partitionRawId, int attemptId) {
+ int packedPartitionId = PackedPartitionId.packedPartitionId(partitionRawId, attemptId);
+ // assertEquals reports the expected and actual values on failure, unlike assertTrue(a == b).
+ Assert.assertEquals(partitionRawId, PackedPartitionId.getRawPartitionId(packedPartitionId));
+ Assert.assertEquals(attemptId, PackedPartitionId.getAttemptId(packedPartitionId));
+ }
+}
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala
new file mode 100644
index 00000000..e4b00d17
--- /dev/null
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.tests.client
+
+import scala.collection.JavaConverters._
+import scala.language.implicitConversions
+
+import org.junit.Assert
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.celeborn.client.{LifecycleManager, ShuffleClient, ShuffleClientImpl}
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.identity.UserIdentifier
+import org.apache.celeborn.common.util.PackedPartitionId
+import org.apache.celeborn.service.deploy.MiniClusterFeature
+
+/**
+ * Integration test that registers map partition tasks through ShuffleClientImpl against a
+ * LifecycleManager backed by a mini cluster.
+ */
+class ShuffleClientSuite extends AnyFunSuite with MiniClusterFeature
+ with BeforeAndAfterAll {
+ val masterPort = 19097
+ val APP = "app-1"
+ var shuffleClient: ShuffleClientImpl = _
+ var lifecycleManager: LifecycleManager = _
+
+ // Starts the mini cluster and wires a shuffle client to a fresh LifecycleManager.
+ override def beforeAll(): Unit = {
+ val masterConf = Map(
+ "celeborn.master.host" -> "localhost",
+ "celeborn.master.port" -> masterPort.toString)
+ val workerConf = Map(
+ "celeborn.master.endpoints" -> s"localhost:$masterPort")
+ setUpMiniCluster(masterConf, workerConf)
+
+ val clientConf = new CelebornConf()
+ .set("celeborn.master.endpoints", s"localhost:$masterPort")
+ .set("celeborn.push.replicate.enabled", "true")
+ .set("celeborn.push.buffer.size", "256K")
+ lifecycleManager = new LifecycleManager(APP, clientConf)
+ shuffleClient = new ShuffleClientImpl(clientConf, UserIdentifier("mock", "mock"))
+ shuffleClient.setupMetaServiceRef(lifecycleManager.self)
+ }
+
+ test("test register map partition task with first attemptId") {
+ val shuffleId = 1
+ val numMappers = 8
+ val mapId = 1
+ val attemptId = 0
+ // JUnit's assertEquals takes (expected, actual); keep that order so failure messages are
+ // reported the right way round.
+ var location =
+ shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId, attemptId)
+ Assert.assertEquals(PackedPartitionId.packedPartitionId(mapId, attemptId), location.getId)
+
+ // retry register
+ location = shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId, attemptId)
+ Assert.assertEquals(PackedPartitionId.packedPartitionId(mapId, attemptId), location.getId)
+
+ // another mapId
+ location =
+ shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId + 1, attemptId)
+ Assert.assertEquals(PackedPartitionId.packedPartitionId(mapId + 1, attemptId), location.getId)
+
+ // offer and reserve all slots: one master location per mapper is expected in total
+ val partitionLocationInfos = lifecycleManager.workerSnapshots(shuffleId).values().asScala
+ val count =
+ partitionLocationInfos.map(r => r.getAllMasterLocations(shuffleId.toString).size()).sum
+ Assert.assertEquals(numMappers, count)
+ }
+
+ override def afterAll(): Unit = {
+ // TODO refactor MiniCluster later
+ // NOTE(review): sys.exit(0) ends the JVM here, presumably because the mini cluster cannot be
+ // shut down cleanly yet - confirm before adding further suites to the same JVM.
+ println("test done")
+ sys.exit(0)
+ }
+}