You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@celeborn.apache.org by zh...@apache.org on 2023/01/31 10:53:40 UTC

[incubator-celeborn] branch main updated: [CELEBORN-201] Separate partitionLocationInfo in LifecycleManager and worker (#1149)

This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 7162be2f [CELEBORN-201] Separate partitionLocationInfo in LifecycleManager and worker (#1149)
7162be2f is described below

commit 7162be2fae752a90edb21bbb4c04d9bfb8d391b3
Author: Shuang <lv...@gmail.com>
AuthorDate: Tue Jan 31 18:53:36 2023 +0800

    [CELEBORN-201] Separate partitionLocationInfo in LifecycleManager and worker (#1149)
---
 .../celeborn/client/ChangePartitionManager.scala   |  4 +-
 .../apache/celeborn/client/LifecycleManager.scala  | 43 ++++------
 .../celeborn/client/commit/CommitHandler.scala     | 82 ++++++++++---------
 .../client/commit/MapPartitionCommitHandler.scala  | 10 +--
 .../commit/ReducePartitionCommitHandler.scala      |  8 +-
 .../celeborn/client/WithShuffleClientSuite.scala   |  4 +-
 .../common/meta/ShufflePartitionLocationInfo.scala | 94 +++++++++++++++++++++
 ...nfo.scala => WorkerPartitionLocationInfo.scala} | 95 +---------------------
 .../celeborn/common/util/CollectionUtils.java      | 41 ++++++++++
 .../service/deploy/worker/Controller.scala         |  4 +-
 .../service/deploy/worker/PushDataHandler.scala    |  4 +-
 .../celeborn/service/deploy/worker/Worker.scala    |  4 +-
 12 files changed, 216 insertions(+), 177 deletions(-)

diff --git a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
index d7daebaa..17b74cf9 100644
--- a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
@@ -293,9 +293,9 @@ class ChangePartitionManager(
           lifecycleManager.workerSnapshots(shuffleId).asScala
             .get(workInfo)
             .foreach { partitionLocationInfo =>
-              partitionLocationInfo.addMasterPartitions(shuffleId.toString, masterLocations)
+              partitionLocationInfo.addMasterPartitions(masterLocations)
               lifecycleManager.updateLatestPartitionLocations(shuffleId, masterLocations)
-              partitionLocationInfo.addSlavePartitions(shuffleId.toString, slaveLocations)
+              partitionLocationInfo.addSlavePartitions(slaveLocations)
             }
           // partition location can be null when call reserveSlotsWithRetry().
           val locations = (masterLocations.asScala ++ slaveLocations.asScala.map(_.getPeer))
diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 1f1fda0b..8e542755 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -17,10 +17,9 @@
 
 package org.apache.celeborn.client
 
-import java.nio.ByteBuffer
 import java.util
 import java.util.{function, List => JList}
-import java.util.concurrent.{Callable, ConcurrentHashMap, ScheduledFuture, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, ScheduledFuture, TimeUnit}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -28,18 +27,17 @@ import scala.util.Random
 
 import com.google.common.annotations.VisibleForTesting
 
-import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers, ShuffleFileGroups}
+import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers}
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.haclient.RssHARetryClient
 import org.apache.celeborn.common.identity.{IdentityProvider, UserIdentifier}
 import org.apache.celeborn.common.internal.Logging
-import org.apache.celeborn.common.meta.{PartitionLocationInfo, WorkerInfo}
+import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo}
 import org.apache.celeborn.common.protocol._
 import org.apache.celeborn.common.protocol.RpcNameConstants.WORKER_EP
 import org.apache.celeborn.common.protocol.message.ControlMessages._
 import org.apache.celeborn.common.protocol.message.StatusCode
 import org.apache.celeborn.common.rpc._
-import org.apache.celeborn.common.rpc.netty.{LocalNettyRpcCallContext, RemoteNettyRpcCallContext}
 import org.apache.celeborn.common.util.{PbSerDeUtils, ThreadUtils, Utils}
 // Can Remove this if celeborn don't support scala211 in future
 import org.apache.celeborn.common.util.FunctionConverter._
@@ -48,7 +46,7 @@ object LifecycleManager {
   type ShuffleFileGroups =
     ConcurrentHashMap[Int, ConcurrentHashMap[Integer, util.Set[PartitionLocation]]]
   type ShuffleAllocatedWorkers =
-    ConcurrentHashMap[Int, ConcurrentHashMap[WorkerInfo, PartitionLocationInfo]]
+    ConcurrentHashMap[Int, ConcurrentHashMap[WorkerInfo, ShufflePartitionLocationInfo]]
   type ShuffleFailedWorkers = ConcurrentHashMap[WorkerInfo, (StatusCode, Long)]
 }
 
@@ -77,7 +75,7 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
   private val userIdentifier: UserIdentifier = IdentityProvider.instantiate(conf).provide()
 
   @VisibleForTesting
-  def workerSnapshots(shuffleId: Int): util.Map[WorkerInfo, PartitionLocationInfo] =
+  def workerSnapshots(shuffleId: Int): util.Map[WorkerInfo, ShufflePartitionLocationInfo] =
     shuffleAllocatedWorkers.get(shuffleId)
 
   val newMapFunc: function.Function[Int, ConcurrentHashMap[Int, PartitionLocation]] =
@@ -350,7 +348,7 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
           val initialLocs = workerSnapshots(shuffleId)
             .values()
             .asScala
-            .flatMap(_.getAllMasterLocationsWithMinEpoch(shuffleId.toString).asScala)
+            .flatMap(_.getAllMasterLocationsWithMinEpoch().asScala)
             .filter(p =>
               (partitionType == PartitionType.REDUCE && p.getEpoch == 0) || (partitionType == PartitionType.MAP && p.getId == partitionId))
             .toArray
@@ -470,12 +468,12 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
       logInfo(s"ReserveSlots for ${Utils.makeShuffleKey(applicationId, shuffleId)} success!")
       logDebug(s"Allocated Slots: $slots")
       // Forth, register shuffle success, update status
-      val allocatedWorkers = new ConcurrentHashMap[WorkerInfo, PartitionLocationInfo]()
+      val allocatedWorkers = new ConcurrentHashMap[WorkerInfo, ShufflePartitionLocationInfo]()
       slots.asScala.foreach { case (workerInfo, (masterLocations, slaveLocations)) =>
-        val partitionLocationInfo = new PartitionLocationInfo()
-        partitionLocationInfo.addMasterPartitions(shuffleId.toString, masterLocations)
+        val partitionLocationInfo = new ShufflePartitionLocationInfo()
+        partitionLocationInfo.addMasterPartitions(masterLocations)
         updateLatestPartitionLocations(shuffleId, masterLocations)
-        partitionLocationInfo.addSlavePartitions(shuffleId.toString, slaveLocations)
+        partitionLocationInfo.addSlavePartitions(slaveLocations)
         allocatedWorkers.put(workerInfo, partitionLocationInfo)
       }
       shuffleAllocatedWorkers.put(shuffleId, allocatedWorkers)
@@ -631,10 +629,8 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
 
     if (commitManager.tryFinalCommit(shuffleId)) {
       // release resources and clear worker info
-      workerSnapshots(shuffleId).asScala.foreach { case (_, partitionLocationInfo) =>
-        partitionLocationInfo.removeMasterPartitions(shuffleId.toString)
-        partitionLocationInfo.removeSlavePartitions(shuffleId.toString)
-      }
+      shuffleAllocatedWorkers.remove(shuffleId)
+
       requestReleaseSlots(
         rssHARetryClient,
         ReleaseSlots(applicationId, shuffleId, List.empty.asJava, List.empty.asJava))
@@ -691,13 +687,10 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
       }
     }
 
-    if (partitionExists(shuffleId)) {
+    if (shuffleResourceExists(shuffleId)) {
       logWarning(s"Partition exists for shuffle $shuffleId, " +
         "maybe caused by task rerun or speculative.")
-      workerSnapshots(shuffleId).asScala.foreach { case (_, partitionLocationInfo) =>
-        partitionLocationInfo.removeMasterPartitions(shuffleId.toString)
-        partitionLocationInfo.removeSlavePartitions(shuffleId.toString)
-      }
+      shuffleAllocatedWorkers.remove(shuffleId)
       requestReleaseSlots(
         rssHARetryClient,
         ReleaseSlots(appId, shuffleId, List.empty.asJava, List.empty.asJava))
@@ -1238,13 +1231,9 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
     }
   }
 
-  private def partitionExists(shuffleId: Int): Boolean = {
+  private def shuffleResourceExists(shuffleId: Int): Boolean = {
     val workers = workerSnapshots(shuffleId)
-    if (workers == null || workers.isEmpty) {
-      false
-    } else {
-      workers.values().asScala.exists(_.containsShuffle(shuffleId.toString))
-    }
+    workers != null && !workers.isEmpty
   }
 
   // Initialize at the end of LifecycleManager construction.
diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
index e9ace39c..16d78106 100644
--- a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
@@ -32,13 +32,13 @@ import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, Shu
 import org.apache.celeborn.client.ShuffleCommittedInfo
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.internal.Logging
-import org.apache.celeborn.common.meta.{PartitionLocationInfo, WorkerInfo}
+import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo}
 import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType}
 import org.apache.celeborn.common.protocol.message.ControlMessages.{CommitFiles, CommitFilesResponse, GetReducerFileGroupResponse}
 import org.apache.celeborn.common.protocol.message.StatusCode
 import org.apache.celeborn.common.rpc.{RpcCallContext, RpcEndpointRef}
 import org.apache.celeborn.common.rpc.netty.{LocalNettyRpcCallContext, RemoteNettyRpcCallContext}
-import org.apache.celeborn.common.util.{ThreadUtils, Utils}
+import org.apache.celeborn.common.util.{CollectionUtils, ThreadUtils, Utils}
 // Can Remove this if celeborn don't support scala211 in future
 import org.apache.celeborn.common.util.FunctionConverter._
 
@@ -211,54 +211,58 @@ abstract class CommitHandler(
 
   def parallelCommitFiles(
       shuffleId: Int,
-      allocatedWorkers: util.Map[WorkerInfo, PartitionLocationInfo],
+      allocatedWorkers: util.Map[WorkerInfo, ShufflePartitionLocationInfo],
       partitionIdOpt: Option[Int] = None): CommitResult = {
     val shuffleCommittedInfo = committedPartitionInfo.get(shuffleId)
     val masterPartMap = new ConcurrentHashMap[String, PartitionLocation]
     val slavePartMap = new ConcurrentHashMap[String, PartitionLocation]
     val commitFilesFailedWorkers = new ShuffleFailedWorkers()
+
+    if (CollectionUtils.isEmpty(allocatedWorkers)) {
+      return CommitResult(masterPartMap, slavePartMap, commitFilesFailedWorkers)
+    }
+
     val commitFileStartTime = System.nanoTime()
     val parallelism = Math.min(allocatedWorkers.size(), conf.rpcMaxParallelism)
     ThreadUtils.parmap(
       allocatedWorkers.asScala.to,
       "CommitFiles",
       parallelism) { case (worker, partitionLocationInfo) =>
-      if (partitionLocationInfo.containsShuffle(shuffleId.toString)) {
-        val masterParts =
-          partitionLocationInfo.getMasterLocations(shuffleId.toString, partitionIdOpt)
-        val slaveParts = partitionLocationInfo.getSlaveLocations(shuffleId.toString, partitionIdOpt)
-        masterParts.asScala.foreach { p =>
-          val partition = new PartitionLocation(p)
-          partition.setFetchPort(worker.fetchPort)
-          partition.setPeer(null)
-          masterPartMap.put(partition.getUniqueId, partition)
-        }
-        slaveParts.asScala.foreach { p =>
-          val partition = new PartitionLocation(p)
-          partition.setFetchPort(worker.fetchPort)
-          partition.setPeer(null)
-          slavePartMap.put(partition.getUniqueId, partition)
-        }
-
-        val (masterIds, slaveIds) = shuffleCommittedInfo.synchronized {
-          (
-            masterParts.asScala
-              .filterNot(shuffleCommittedInfo.handledPartitionLocations.contains)
-              .map(_.getUniqueId).asJava,
-            slaveParts.asScala
-              .filterNot(shuffleCommittedInfo.handledPartitionLocations.contains)
-              .map(_.getUniqueId).asJava)
-        }
+      val masterParts =
+        partitionLocationInfo.getMasterPartitions(partitionIdOpt)
+      val slaveParts = partitionLocationInfo.getSlavePartitions(partitionIdOpt)
+      masterParts.asScala.foreach { p =>
+        val partition = new PartitionLocation(p)
+        partition.setFetchPort(worker.fetchPort)
+        partition.setPeer(null)
+        masterPartMap.put(partition.getUniqueId, partition)
+      }
+      slaveParts.asScala.foreach { p =>
+        val partition = new PartitionLocation(p)
+        partition.setFetchPort(worker.fetchPort)
+        partition.setPeer(null)
+        slavePartMap.put(partition.getUniqueId, partition)
+      }
 
-        commitFiles(
-          appId,
-          shuffleId,
-          shuffleCommittedInfo,
-          worker,
-          masterIds,
-          slaveIds,
-          commitFilesFailedWorkers)
+      val (masterIds, slaveIds) = shuffleCommittedInfo.synchronized {
+        (
+          masterParts.asScala
+            .filterNot(shuffleCommittedInfo.handledPartitionLocations.contains)
+            .map(_.getUniqueId).asJava,
+          slaveParts.asScala
+            .filterNot(shuffleCommittedInfo.handledPartitionLocations.contains)
+            .map(_.getUniqueId).asJava)
       }
+
+      commitFiles(
+        appId,
+        shuffleId,
+        shuffleCommittedInfo,
+        worker,
+        masterIds,
+        slaveIds,
+        commitFilesFailedWorkers)
+
     }
 
     logInfo(s"Shuffle $shuffleId " +
@@ -277,6 +281,10 @@ abstract class CommitHandler(
       slaveIds: util.List[String],
       commitFilesFailedWorkers: ShuffleFailedWorkers): Unit = {
 
+    if (CollectionUtils.isEmpty(masterIds) && CollectionUtils.isEmpty(slaveIds)) {
+      return
+    }
+
     val res =
       if (!testRetryCommitFiles) {
         val commitFiles = CommitFiles(
diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
index a4a9b5cb..74498f52 100644
--- a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
@@ -26,11 +26,11 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable
 
 import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
-import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers, ShuffleFileGroups}
+import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers}
 import org.apache.celeborn.client.ShuffleCommittedInfo
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.internal.Logging
-import org.apache.celeborn.common.meta.{PartitionLocationInfo, WorkerInfo}
+import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo}
 import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType}
 // Can Remove this if celeborn don't support scala211 in future
 import org.apache.celeborn.common.util.FunctionConverter._
@@ -123,7 +123,7 @@ class MapPartitionCommitHandler(
 
   private def handleFinalPartitionCommitFiles(
       shuffleId: Int,
-      allocatedWorkers: util.Map[WorkerInfo, PartitionLocationInfo],
+      allocatedWorkers: util.Map[WorkerInfo, ShufflePartitionLocationInfo],
       partitionId: Int): (Boolean, ShuffleFailedWorkers) = {
     val shuffleCommittedInfo = committedPartitionInfo.get(shuffleId)
     // commit files
@@ -190,7 +190,7 @@ class MapPartitionCommitHandler(
     inProcessingPartitionIds.add(partitionId)
 
     val partitionAllocatedWorkers = allocatedWorkers.get(shuffleId).asScala.filter(p =>
-      p._2.containsPartition(shuffleId.toString, partitionId)).asJava
+      p._2.containsPartition(partitionId)).asJava
 
     var dataCommitSuccess = true
     if (!partitionAllocatedWorkers.isEmpty) {
@@ -205,7 +205,7 @@ class MapPartitionCommitHandler(
 
     // release resources and clear related info
     partitionAllocatedWorkers.asScala.foreach { case (_, partitionLocationInfo) =>
-      partitionLocationInfo.removeRelatedPartitions(shuffleId.toString, partitionId)
+      partitionLocationInfo.removePartitions(partitionId)
     }
 
     inProcessingPartitionIds.remove(partitionId)
diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
index 99bd7deb..9d6423d7 100644
--- a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
@@ -24,14 +24,14 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable
 
 import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
-import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers, ShuffleFileGroups}
+import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers}
 import org.apache.celeborn.client.ShuffleCommittedInfo
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.internal.Logging
-import org.apache.celeborn.common.meta.{PartitionLocationInfo, WorkerInfo}
+import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo}
 import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType}
-import org.apache.celeborn.common.protocol.message.{ControlMessages, StatusCode}
 import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
+import org.apache.celeborn.common.protocol.message.StatusCode
 import org.apache.celeborn.common.rpc.RpcCallContext
 
 /**
@@ -126,7 +126,7 @@ class ReducePartitionCommitHandler(
 
   private def handleFinalCommitFiles(
       shuffleId: Int,
-      allocatedWorkers: util.Map[WorkerInfo, PartitionLocationInfo])
+      allocatedWorkers: util.Map[WorkerInfo, ShufflePartitionLocationInfo])
       : (Boolean, ShuffleFailedWorkers) = {
     val shuffleCommittedInfo = committedPartitionInfo.get(shuffleId)
 
diff --git a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala
index 618d4a38..a8dcb4e7 100644
--- a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala
+++ b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala
@@ -59,7 +59,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite {
     // check all allocated slots
     var partitionLocationInfos = lifecycleManager.workerSnapshots(shuffleId).values().asScala
     var count =
-      partitionLocationInfos.map(r => r.getAllMasterLocations(shuffleId.toString).size()).sum
+      partitionLocationInfos.map(r => r.getMasterPartitions().size()).sum
     Assert.assertEquals(count, numMappers)
 
     // another mapId
@@ -78,7 +78,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite {
     partitionLocationInfos = lifecycleManager.workerSnapshots(shuffleId).values().asScala
     logInfo(partitionLocationInfos.toString())
     count =
-      partitionLocationInfos.map(r => r.getAllMasterLocations(shuffleId.toString).size()).sum
+      partitionLocationInfos.map(r => r.getMasterPartitions().size()).sum
     Assert.assertEquals(count, numMappers + 1)
   }
 
diff --git a/common/src/main/scala/org/apache/celeborn/common/meta/ShufflePartitionLocationInfo.scala b/common/src/main/scala/org/apache/celeborn/common/meta/ShufflePartitionLocationInfo.scala
new file mode 100644
index 00000000..21427d4e
--- /dev/null
+++ b/common/src/main/scala/org/apache/celeborn/common/meta/ShufflePartitionLocationInfo.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.meta
+
+import java.util
+import java.util.concurrent.ConcurrentHashMap
+
+import scala.collection.JavaConverters._
+
+import org.apache.celeborn.common.protocol.PartitionLocation
+
+class ShufflePartitionLocationInfo {
+
+  type PartitionInfo = ConcurrentHashMap[Int, util.List[PartitionLocation]]
+  private val masterPartitionLocations = new PartitionInfo
+  private val slavePartitionLocations = new PartitionInfo
+
+  def addMasterPartitions(masterLocations: util.List[PartitionLocation]) = {
+    addPartitions(masterPartitionLocations, masterLocations)
+  }
+
+  def addSlavePartitions(slaveLocations: util.List[PartitionLocation]) = {
+    addPartitions(slavePartitionLocations, slaveLocations)
+  }
+
+  def getMasterPartitions(partitionIdOpt: Option[Int] = None): util.List[PartitionLocation] = {
+    getPartitions(masterPartitionLocations, partitionIdOpt)
+  }
+
+  def getSlavePartitions(partitionIdOpt: Option[Int] = None): util.List[PartitionLocation] = {
+    getPartitions(slavePartitionLocations, partitionIdOpt)
+  }
+
+  def containsPartition(partitionId: Int): Boolean = {
+    masterPartitionLocations.containsKey(partitionId) ||
+    slavePartitionLocations.containsKey(partitionId)
+  }
+
+  def removePartitions(partitionId: Int): Unit = {
+    masterPartitionLocations.remove(partitionId)
+    slavePartitionLocations.remove(partitionId)
+  }
+
+  def getAllMasterLocationsWithMinEpoch(): util.List[PartitionLocation] = {
+    def order(a: Int, b: Int): Boolean = a < b
+
+    masterPartitionLocations.values().asScala.map { list =>
+      var loc = list.get(0)
+      1 until list.size() foreach (ind => {
+        if (order(list.get(ind).getEpoch, loc.getEpoch)) {
+          loc = list.get(ind)
+        }
+      })
+      loc
+    }.toList.asJava
+  }
+
+  private def addPartitions(
+      partitionInfo: PartitionInfo,
+      locations: util.List[PartitionLocation]): Unit = {
+    if (locations != null && locations.size() > 0) {
+      locations.asScala.foreach { loc =>
+        partitionInfo.putIfAbsent(loc.getId, new util.ArrayList)
+        val locations = partitionInfo.get(loc.getId)
+        locations.add(loc)
+      }
+    }
+  }
+
+  private def getPartitions(
+      partitionInfo: PartitionInfo,
+      partitionIdOpt: Option[Int]): util.List[PartitionLocation] = {
+    partitionIdOpt match {
+      case Some(partitionId) =>
+        partitionInfo.getOrDefault(partitionId, new util.ArrayList)
+      case _ => partitionInfo.values().asScala.flatMap(_.asScala).toList.asJava
+    }
+  }
+}
diff --git a/common/src/main/scala/org/apache/celeborn/common/meta/PartitionLocationInfo.scala b/common/src/main/scala/org/apache/celeborn/common/meta/WorkerPartitionLocationInfo.scala
similarity index 72%
rename from common/src/main/scala/org/apache/celeborn/common/meta/PartitionLocationInfo.scala
rename to common/src/main/scala/org/apache/celeborn/common/meta/WorkerPartitionLocationInfo.scala
index e23d2ece..b716d65a 100644
--- a/common/src/main/scala/org/apache/celeborn/common/meta/PartitionLocationInfo.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/meta/WorkerPartitionLocationInfo.scala
@@ -25,7 +25,7 @@ import scala.collection.JavaConverters._
 import org.apache.celeborn.common.internal.Logging
 import org.apache.celeborn.common.protocol.PartitionLocation
 
-class PartitionLocationInfo extends Logging {
+class WorkerPartitionLocationInfo extends Logging {
 
   // key: ShuffleKey, values: (partitionId -> partition locations)
   type PartitionInfo = util.HashMap[String, util.Map[Int, util.List[PartitionLocation]]]
@@ -46,23 +46,6 @@ class PartitionLocationInfo extends Logging {
     slavePartitionLocations.containsKey(shuffleKey)
   }
 
-  def containsPartition(shuffleKey: String, partitionId: Int): Boolean = this
-    .synchronized {
-      val contain = masterPartitionLocations.containsKey(
-        shuffleKey) && masterPartitionLocations.get(shuffleKey).containsKey(partitionId)
-      contain || (slavePartitionLocations.containsKey(shuffleKey) && slavePartitionLocations.get(
-        shuffleKey)
-        .containsKey(partitionId))
-    }
-
-  def addMasterPartition(shuffleKey: String, location: PartitionLocation): Int = {
-    addPartition(shuffleKey, location, masterPartitionLocations)
-  }
-
-  def addSlavePartition(shuffleKey: String, location: PartitionLocation): Int = {
-    addPartition(shuffleKey, location, slavePartitionLocations)
-  }
-
   def addMasterPartitions(
       shuffleKey: String,
       locations: util.List[PartitionLocation]): Unit = {
@@ -149,70 +132,6 @@ class PartitionLocationInfo extends Logging {
     removePartitions(shuffleKey, uniqueIds, slavePartitionLocations)
   }
 
-  def getAllMasterLocationsWithMinEpoch(shuffleKey: String): util.List[PartitionLocation] =
-    this.synchronized {
-      getAllMasterLocationsWithExtremeEpoch(shuffleKey, (a, b) => a < b)
-    }
-
-  def getAllMasterLocationsWithExtremeEpoch(
-      shuffleKey: String,
-      order: (Int, Int) => Boolean): util.List[PartitionLocation] = this.synchronized {
-    if (masterPartitionLocations.containsKey(shuffleKey)) {
-      masterPartitionLocations.get(shuffleKey)
-        .values()
-        .asScala
-        .map { list =>
-          var loc = list.get(0)
-          1 until list.size() foreach (ind => {
-            if (order(list.get(ind).getEpoch, loc.getEpoch)) {
-              loc = list.get(ind)
-            }
-          })
-          loc
-        }.toList.asJava
-    } else {
-      new util.ArrayList[PartitionLocation]()
-    }
-  }
-
-  def getLocationWithMaxEpoch(
-      shuffleKey: String,
-      partitionId: Int): Option[PartitionLocation] = this.synchronized {
-    if (!masterPartitionLocations.containsKey(shuffleKey) ||
-      !masterPartitionLocations.get(shuffleKey).containsKey(partitionId)) {
-      return None
-    }
-    val locations = masterPartitionLocations.get(shuffleKey).get(partitionId)
-    if (locations == null || locations.size() == 0) {
-      return None
-    }
-    var currentEpoch = -1
-    var currentPartition: PartitionLocation = null
-    locations.asScala.foreach(loc => {
-      if (loc.getEpoch > currentEpoch) {
-        currentEpoch = loc.getEpoch
-        currentPartition = loc
-      }
-    })
-    Some(currentPartition)
-  }
-
-  private def addPartition(
-      shuffleKey: String,
-      location: PartitionLocation,
-      partitionInfo: PartitionInfo): Int = this.synchronized {
-    if (location != null) {
-      partitionInfo.putIfAbsent(shuffleKey, new util.HashMap[Int, util.List[PartitionLocation]]())
-      val reduceLocMap = partitionInfo.get(shuffleKey)
-      reduceLocMap.putIfAbsent(location.getId, new util.ArrayList[PartitionLocation]())
-      val locations = reduceLocMap.get(location.getId)
-      locations.add(location)
-      1
-    } else {
-      0
-    }
-  }
-
   private def addPartitions(
       shuffleKey: String,
       locations: util.List[PartitionLocation],
@@ -228,18 +147,6 @@ class PartitionLocationInfo extends Logging {
     }
   }
 
-  def removeRelatedPartitions(
-      shuffleKey: String,
-      partitionId: Int): Unit = this
-    .synchronized {
-      if (masterPartitionLocations.containsKey(shuffleKey)) {
-        masterPartitionLocations.get(shuffleKey).remove(partitionId)
-      }
-      if (slavePartitionLocations.containsKey(shuffleKey)) {
-        slavePartitionLocations.get(shuffleKey).remove(partitionId)
-      }
-    }
-
   /**
    * @param shuffleKey
    * @param uniqueIds
diff --git a/common/src/main/scala/org/apache/celeborn/common/util/CollectionUtils.java b/common/src/main/scala/org/apache/celeborn/common/util/CollectionUtils.java
new file mode 100644
index 00000000..53316d56
--- /dev/null
+++ b/common/src/main/scala/org/apache/celeborn/common/util/CollectionUtils.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.util;
+
+import java.util.Collection;
+import java.util.Map;
+
+public class CollectionUtils {
+
+  public static boolean isEmpty(Collection collection) {
+    return collection == null || collection.isEmpty();
+  }
+
+  public static boolean isNotEmpty(Collection collection) {
+    return !isEmpty(collection);
+  }
+
+  public static boolean isEmpty(Map map) {
+    return map == null || map.isEmpty();
+  }
+
+  public static boolean isNotEmpty(Map map) {
+    return !isEmpty(map);
+  }
+
+}
diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
index f4f311cc..24e2859d 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
@@ -31,7 +31,7 @@ import org.roaringbitmap.RoaringBitmap
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.identity.UserIdentifier
 import org.apache.celeborn.common.internal.Logging
-import org.apache.celeborn.common.meta.{PartitionLocationInfo, WorkerInfo}
+import org.apache.celeborn.common.meta.{WorkerInfo, WorkerPartitionLocationInfo}
 import org.apache.celeborn.common.metrics.MetricsSystem
 import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionSplitMode, PartitionType, StorageInfo}
 import org.apache.celeborn.common.protocol.message.ControlMessages._
@@ -53,7 +53,7 @@ private[deploy] class Controller(
   var shuffleCommitInfos: ConcurrentHashMap[String, ConcurrentHashMap[Long, CommitInfo]] = _
   var shufflePartitionType: ConcurrentHashMap[String, PartitionType] = _
   var workerInfo: WorkerInfo = _
-  var partitionLocationInfo: PartitionLocationInfo = _
+  var partitionLocationInfo: WorkerPartitionLocationInfo = _
   var timer: HashedWheelTimer = _
   var commitThreadPool: ThreadPoolExecutor = _
   var asyncReplyPool: ScheduledExecutorService = _
diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
index ccc254c5..a971e785 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
@@ -27,7 +27,7 @@ import io.netty.buffer.ByteBuf
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.exception.AlreadyClosedException
 import org.apache.celeborn.common.internal.Logging
-import org.apache.celeborn.common.meta.{PartitionLocationInfo, WorkerInfo}
+import org.apache.celeborn.common.meta.{WorkerInfo, WorkerPartitionLocationInfo}
 import org.apache.celeborn.common.metrics.source.RPCSource
 import org.apache.celeborn.common.network.buffer.{NettyManagedBuffer, NioManagedBuffer}
 import org.apache.celeborn.common.network.client.{RpcResponseCallback, TransportClient, TransportClientFactory}
@@ -44,7 +44,7 @@ class PushDataHandler extends BaseMessageHandler with Logging {
 
   var workerSource: WorkerSource = _
   var rpcSource: RPCSource = _
-  var partitionLocationInfo: PartitionLocationInfo = _
+  var partitionLocationInfo: WorkerPartitionLocationInfo = _
   var shuffleMapperAttempts: ConcurrentHashMap[String, AtomicIntegerArray] = _
   var shufflePartitionType: ConcurrentHashMap[String, PartitionType] = _
   var replicateThreadPool: ThreadPoolExecutor = _
diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
index 1b0e6cfd..3e61a9d9 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
@@ -35,7 +35,7 @@ import org.apache.celeborn.common.exception.CelebornException
 import org.apache.celeborn.common.haclient.RssHARetryClient
 import org.apache.celeborn.common.identity.UserIdentifier
 import org.apache.celeborn.common.internal.Logging
-import org.apache.celeborn.common.meta.{DiskInfo, PartitionLocationInfo, WorkerInfo}
+import org.apache.celeborn.common.meta.{DiskInfo, WorkerInfo, WorkerPartitionLocationInfo}
 import org.apache.celeborn.common.metrics.MetricsSystem
 import org.apache.celeborn.common.metrics.source.{JVMCPUSource, JVMSource, RPCSource}
 import org.apache.celeborn.common.network.TransportContext
@@ -180,7 +180,7 @@ private[celeborn] class Worker(
   val registered = new AtomicBoolean(false)
   val shuffleMapperAttempts = new ConcurrentHashMap[String, AtomicIntegerArray]()
   val shufflePartitionType = new ConcurrentHashMap[String, PartitionType]
-  val partitionLocationInfo = new PartitionLocationInfo
+  val partitionLocationInfo = new WorkerPartitionLocationInfo
 
   val shuffleCommitInfos = new ConcurrentHashMap[String, ConcurrentHashMap[Long, CommitInfo]]()