Posted to commits@celeborn.apache.org by zh...@apache.org on 2022/11/29 05:09:06 UTC

[incubator-celeborn] branch main updated: [CELEBORN-76] Support batch commit hard split partition before stage end

This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new d26e7320 [CELEBORN-76] Support batch commit hard split partition before stage end
d26e7320 is described below

commit d26e73209b0e56c2b0ab9251ffda7116f2d02840
Author: Angerszhuuuu <an...@gmail.com>
AuthorDate: Tue Nov 29 13:09:01 2022 +0800

    [CELEBORN-76] Support batch commit hard split partition before stage end
---
 .../apache/celeborn/client/LifecycleManager.scala  | 412 +++++++++++++++------
 .../org/apache/celeborn/common/CelebornConf.scala  |  28 ++
 docs/configuration/client.md                       |   3 +
 .../service/deploy/worker/PushDataHandler.scala    |  42 ++-
 4 files changed, 367 insertions(+), 118 deletions(-)
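
Before diving into the diff, a minimal, self-contained Scala sketch of the pattern this commit adds: the push-failure path enqueues HARD_SPLIT commit requests per shuffle, and a scheduled thread periodically drains the queue, skips partitions an earlier batch already handled, and issues one CommitFiles call per worker. All names here (Worker, Partition, CommitRequest, the commitFiles stub) are illustrative placeholders, not the Celeborn API; the real LifecycleManager additionally re-checks the stage-end state under a per-shuffle lock and records results in ShuffleCommittedInfo, as the diff below shows.

import java.util.concurrent.{ConcurrentHashMap, Executors, TimeUnit}

import scala.collection.JavaConverters._

// Illustrative placeholders, not Celeborn types.
final case class Worker(host: String, port: Int)
final case class Partition(uniqueId: String, worker: Worker)
final case class CommitRequest(shuffleId: Int, partition: Partition)

object BatchCommitSketch {
  // shuffleId -> queued HARD_SPLIT commit requests, appended by the push-failure path.
  private val pending = new ConcurrentHashMap[Int, java.util.Set[CommitRequest]]()
  // shuffleId -> partitions whose commit was already issued by an earlier batch.
  private val handled = new ConcurrentHashMap[Int, java.util.Set[Partition]]()

  private def setFor[T](map: ConcurrentHashMap[Int, java.util.Set[T]], key: Int): java.util.Set[T] = {
    map.putIfAbsent(key, ConcurrentHashMap.newKeySet[T]())
    map.get(key)
  }

  def enqueue(request: CommitRequest): Unit =
    setFor(pending, request.shuffleId).add(request)

  // Stand-in for the real CommitFiles RPC to a worker.
  private def commitFiles(worker: Worker, partitionIds: Seq[String]): Unit =
    println(s"CommitFiles -> $worker : ${partitionIds.mkString(",")}")

  def start(intervalMs: Long): Unit = {
    val scheduler = Executors.newSingleThreadScheduledExecutor()
    scheduler.scheduleAtFixedRate(
      new Runnable {
        override def run(): Unit = pending.asScala.foreach { case (shuffleId, requests) =>
          val done = setFor(handled, shuffleId)
          // Drain the queue, dropping partitions an earlier batch already committed.
          val batch = requests.synchronized {
            val drained = requests.asScala.toList.filterNot(r => done.contains(r.partition))
            requests.clear()
            drained.foreach(r => done.add(r.partition))
            drained
          }
          // One CommitFiles call per worker, mirroring the batched LifecycleManager path.
          batch.groupBy(_.partition.worker).foreach { case (worker, reqs) =>
            commitFiles(worker, reqs.map(_.partition.uniqueId))
          }
        }
      },
      0L,
      intervalMs,
      TimeUnit.MILLISECONDS)
  }
}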

diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index bd2887d7..bed13215 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -21,7 +21,7 @@ import java.nio.ByteBuffer
 import java.util
 import java.util.{function, List => JList}
 import java.util.concurrent.{Callable, ConcurrentHashMap, ScheduledExecutorService, ScheduledFuture, TimeUnit}
-import java.util.concurrent.atomic.{AtomicLong, LongAdder}
+import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, LongAdder}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -127,6 +127,27 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
   // shuffleId -> set of partition id
   private val inBatchPartitions = new ConcurrentHashMap[Int, util.Set[Integer]]()
 
+  case class CommitPartitionRequest(
+      applicationId: String,
+      shuffleId: Int,
+      partition: PartitionLocation)
+
+  case class ShuffleCommittedInfo(
+      committedMasterIds: util.List[String],
+      committedSlaveIds: util.List[String],
+      failedMasterPartitionIds: ConcurrentHashMap[String, WorkerInfo],
+      failedSlavePartitionIds: ConcurrentHashMap[String, WorkerInfo],
+      committedMasterStorageInfos: ConcurrentHashMap[String, StorageInfo],
+      committedSlaveStorageInfos: ConcurrentHashMap[String, StorageInfo],
+      committedMapIdBitmap: ConcurrentHashMap[String, RoaringBitmap],
+      currentShuffleFileCount: LongAdder,
+      commitPartitionRequests: util.Set[CommitPartitionRequest],
+      handledCommitPartitionRequests: util.Set[PartitionLocation],
+      inFlightCommitRequest: AtomicInteger)
+
+  // shuffle id -> ShuffleCommittedInfo
+  private val committedPartitionInfo = new ConcurrentHashMap[Int, ShuffleCommittedInfo]()
+
   // register shuffle request waiting for response
   private val registeringShuffleRequest =
     new ConcurrentHashMap[Int, util.Set[RegisterCallContext]]()
@@ -162,6 +183,20 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
       None
     }
 
+  private val batchHandleCommitPartitionEnabled = conf.batchHandleCommitPartitionEnabled
+  private val batchHandleCommitPartitionExecutors = ThreadUtils.newDaemonCachedThreadPool(
+    "rss-lifecycle-manager-commit-partition-executor",
+    conf.batchHandleCommitPartitionNumThreads)
+  private val batchHandleCommitPartitionRequestInterval =
+    conf.batchHandleCommitPartitionRequestInterval
+  private val batchHandleCommitPartitionSchedulerThread: Option[ScheduledExecutorService] =
+    if (batchHandleCommitPartitionEnabled) {
+      Some(ThreadUtils.newDaemonSingleThreadScheduledExecutor(
+        "rss-lifecycle-manager-commit-partition-scheduler"))
+    } else {
+      None
+    }
+
   // init driver rss meta rpc service
   override val rpcEnv: RpcEnv = RpcEnv.create(
     RpcNameConstants.RSS_METASERVICE_SYS,
@@ -249,6 +284,111 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
         batchHandleChangePartitionRequestInterval,
         TimeUnit.MILLISECONDS)
     }
+
+    batchHandleCommitPartitionSchedulerThread.foreach {
+      _.scheduleAtFixedRate(
+        new Runnable {
+          override def run(): Unit = {
+            committedPartitionInfo.asScala.foreach { case (shuffleId, shuffleCommittedInfo) =>
+              batchHandleCommitPartitionExecutors.submit {
+                new Runnable {
+                  override def run(): Unit = {
+                    if (inProcessStageEndShuffleSet.contains(shuffleId) ||
+                      stageEndShuffleSet.contains(shuffleId)) {
+                      logWarning(s"Shuffle $shuffleId ended or during processing stage end.")
+                      shuffleCommittedInfo.synchronized {
+                        shuffleCommittedInfo.commitPartitionRequests.clear()
+                      }
+                    } else {
+                      val currentBatch = shuffleCommittedInfo.synchronized {
+                        val batch = new util.HashSet[CommitPartitionRequest]()
+                        batch.addAll(shuffleCommittedInfo.commitPartitionRequests)
+                        val currentBatch = batch.asScala.filterNot { request =>
+                          shuffleCommittedInfo.handledCommitPartitionRequests
+                            .contains(request.partition)
+                        }
+                        shuffleCommittedInfo.commitPartitionRequests.clear()
+                        currentBatch.foreach { commitPartitionRequest =>
+                          shuffleCommittedInfo.handledCommitPartitionRequests
+                            .add(commitPartitionRequest.partition)
+                          if (commitPartitionRequest.partition.getPeer != null) {
+                            shuffleCommittedInfo.handledCommitPartitionRequests
+                              .add(commitPartitionRequest.partition.getPeer)
+                          }
+                        }
+                        // If handleStageEnd acquired the lock first and already ran commitFiles,
+                        // then by the time this batch gets the lock commitPartitionRequests may
+                        // contain partitions that the stage-end process has already committed.
+                        // In that case inProcessStageEndShuffleSet already contains this shuffle
+                        // id, so simply return an empty batch.
+                        if (inProcessStageEndShuffleSet.contains(shuffleId) ||
+                          stageEndShuffleSet.contains(shuffleId)) {
+                          logWarning(s"Shuffle $shuffleId ended or during processing stage end.")
+                          Seq.empty
+                        } else {
+                          currentBatch
+                        }
+                      }
+                      if (currentBatch.nonEmpty) {
+                        logWarning(s"Commit current batch HARD_SPLIT partitions for $shuffleId: " +
+                          s"${currentBatch.map(_.partition.getUniqueId).mkString("[", ",", "]")}")
+                        val workerToRequests = currentBatch.flatMap { request =>
+                          if (request.partition.getPeer != null) {
+                            Seq(request.partition, request.partition.getPeer)
+                          } else {
+                            Seq(request.partition)
+                          }
+                        }.groupBy(_.getWorker)
+                        val commitFilesFailedWorkers =
+                          new ConcurrentHashMap[WorkerInfo, (StatusCode, Long)]()
+                        val parallelism = workerToRequests.size
+                        ThreadUtils.parmap(
+                          workerToRequests.to,
+                          "CommitFiles",
+                          parallelism) {
+                          case (worker, requests) =>
+                            val workerInfo =
+                              shuffleAllocatedWorkers
+                                .get(shuffleId)
+                                .asScala
+                                .find(_._1.equals(worker))
+                                .get
+                                ._1
+                            val mastersIds =
+                              requests
+                                .filter(_.getMode == PartitionLocation.Mode.MASTER)
+                                .map(_.getUniqueId)
+                                .toList
+                                .asJava
+                            val slaveIds =
+                              requests
+                                .filter(_.getMode == PartitionLocation.Mode.SLAVE)
+                                .map(_.getUniqueId)
+                                .toList
+                                .asJava
+
+                            commitFiles(
+                              appId,
+                              shuffleId,
+                              shuffleCommittedInfo,
+                              workerInfo,
+                              mastersIds,
+                              slaveIds,
+                              commitFilesFailedWorkers)
+                        }
+                        recordWorkerFailure(commitFilesFailedWorkers)
+                      }
+                    }
+                  }
+                }
+              }
+            }
+          }
+        },
+        0,
+        batchHandleCommitPartitionRequestInterval,
+        TimeUnit.MILLISECONDS)
+    }
   }
 
   override def onStart(): Unit = {
@@ -576,6 +716,20 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
 
       // Fifth, reply the allocated partition location to ShuffleClient.
       logInfo(s"Handle RegisterShuffle Success for $shuffleId.")
+      committedPartitionInfo.put(
+        shuffleId,
+        ShuffleCommittedInfo(
+          new util.ArrayList[String](),
+          new util.ArrayList[String](),
+          new ConcurrentHashMap[String, WorkerInfo](),
+          new ConcurrentHashMap[String, WorkerInfo](),
+          new ConcurrentHashMap[String, StorageInfo](),
+          new ConcurrentHashMap[String, StorageInfo](),
+          new ConcurrentHashMap[String, RoaringBitmap](),
+          new LongAdder,
+          new util.HashSet[CommitPartitionRequest](),
+          new util.HashSet[PartitionLocation](),
+          new AtomicInteger()))
       val allMasterPartitionLocations = slots.asScala.flatMap(_._2._1.asScala).toArray
       reply(RegisterShuffleResponse(StatusCode.SUCCESS, allMasterPartitionLocations))
     }
@@ -692,6 +846,16 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
     // check if there exists request for the partition, if do just register
     val requests = changePartitionRequests.computeIfAbsent(shuffleId, rpcContextRegisterFunc)
     inBatchPartitions.computeIfAbsent(shuffleId, inBatchShuffleIdRegisterFunc)
+
+    // handle hard split
+    if (batchHandleCommitPartitionEnabled && cause.isDefined && cause.get == StatusCode.HARD_SPLIT) {
+      val shuffleCommittedInfo = committedPartitionInfo.get(shuffleId)
+      shuffleCommittedInfo.synchronized {
+        shuffleCommittedInfo.commitPartitionRequests
+          .add(CommitPartitionRequest(applicationId, shuffleId, oldPartition))
+      }
+    }
+
     requests.synchronized {
       if (requests.containsKey(partitionId)) {
         requests.get(partitionId).add(changePartition)
@@ -951,18 +1115,10 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
     // ask allLocations workers holding partitions to commit files
     val masterPartMap = new ConcurrentHashMap[String, PartitionLocation]
     val slavePartMap = new ConcurrentHashMap[String, PartitionLocation]
-    val committedMasterIds = ConcurrentHashMap.newKeySet[String]()
-    val committedSlaveIds = ConcurrentHashMap.newKeySet[String]()
-    val committedMasterStorageInfos = new ConcurrentHashMap[String, StorageInfo]()
-    val committedSlaveStorageInfos = new ConcurrentHashMap[String, StorageInfo]()
-    val committedMapIdBitmap = new ConcurrentHashMap[String, RoaringBitmap]()
-    val failedMasterPartitionIds = new ConcurrentHashMap[String, WorkerInfo]()
-    val failedSlavePartitionIds = new ConcurrentHashMap[String, WorkerInfo]()
 
     val allocatedWorkers = shuffleAllocatedWorkers.get(shuffleId)
+    val shuffleCommittedInfo = committedPartitionInfo.get(shuffleId)
     val commitFilesFailedWorkers = new ConcurrentHashMap[WorkerInfo, (StatusCode, Long)]()
-
-    val currentShuffleFileCount = new LongAdder
     val commitFileStartTime = System.nanoTime()
 
     val parallelism = Math.min(workerSnapshots(shuffleId).size(), conf.rpcMaxParallelism)
@@ -986,93 +1142,35 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
           slavePartMap.put(partition.getUniqueId, partition)
         }
 
-        val masterIds = masterParts.asScala.map(_.getUniqueId).asJava
-        val slaveIds = slaveParts.asScala.map(_.getUniqueId).asJava
-
-        val res =
-          if (!testRetryCommitFiles) {
-            val commitFiles = CommitFiles(
-              applicationId,
-              shuffleId,
-              masterIds,
-              slaveIds,
-              shuffleMapperAttempts.get(shuffleId),
-              commitEpoch.incrementAndGet())
-            val res = requestCommitFilesWithRetry(worker.endpoint, commitFiles)
-
-            res.status match {
-              case StatusCode.SUCCESS => // do nothing
-              case StatusCode.PARTIAL_SUCCESS | StatusCode.SHUFFLE_NOT_REGISTERED | StatusCode.FAILED =>
-                logDebug(s"Request $commitFiles return ${res.status} for " +
-                  s"${Utils.makeShuffleKey(applicationId, shuffleId)}")
-                commitFilesFailedWorkers.put(worker, (res.status, System.currentTimeMillis()))
-              case _ => // won't happen
-            }
-            res
-          } else {
-            // for test
-            val commitFiles1 = CommitFiles(
-              applicationId,
-              shuffleId,
-              masterIds.subList(0, masterIds.size() / 2),
-              slaveIds.subList(0, slaveIds.size() / 2),
-              shuffleMapperAttempts.get(shuffleId),
-              commitEpoch.incrementAndGet())
-            val res1 = requestCommitFilesWithRetry(worker.endpoint, commitFiles1)
-
-            val commitFiles = CommitFiles(
-              applicationId,
-              shuffleId,
-              masterIds.subList(masterIds.size() / 2, masterIds.size()),
-              slaveIds.subList(slaveIds.size() / 2, slaveIds.size()),
-              shuffleMapperAttempts.get(shuffleId),
-              commitEpoch.incrementAndGet())
-            val res2 = requestCommitFilesWithRetry(worker.endpoint, commitFiles)
-
-            res1.committedMasterStorageInfos.putAll(res2.committedMasterStorageInfos)
-            res1.committedSlaveStorageInfos.putAll(res2.committedSlaveStorageInfos)
-            res1.committedMapIdBitMap.putAll(res2.committedMapIdBitMap)
-            CommitFilesResponse(
-              status = if (res1.status == StatusCode.SUCCESS) res2.status else res1.status,
-              (res1.committedMasterIds.asScala ++ res2.committedMasterIds.asScala).toList.asJava,
-              (res1.committedSlaveIds.asScala ++ res1.committedSlaveIds.asScala).toList.asJava,
-              (res1.failedMasterIds.asScala ++ res1.failedMasterIds.asScala).toList.asJava,
-              (res1.failedSlaveIds.asScala ++ res2.failedSlaveIds.asScala).toList.asJava,
-              res1.committedMasterStorageInfos,
-              res1.committedSlaveStorageInfos,
-              res1.committedMapIdBitMap,
-              res1.totalWritten + res2.totalWritten,
-              res1.fileCount + res2.fileCount)
-          }
-
-        // record committed partitionIds
-        committedMasterIds.addAll(res.committedMasterIds)
-        committedSlaveIds.addAll(res.committedSlaveIds)
-
-        // record committed partitions storage hint and disk hint
-        committedMasterStorageInfos.putAll(res.committedMasterStorageInfos)
-        committedSlaveStorageInfos.putAll(res.committedSlaveStorageInfos)
-
-        // record failed partitions
-        failedMasterPartitionIds.putAll(res.failedMasterIds.asScala.map((_, worker)).toMap.asJava)
-        failedSlavePartitionIds.putAll(res.failedSlaveIds.asScala.map((_, worker)).toMap.asJava)
-
-        if (!res.committedMapIdBitMap.isEmpty) {
-          committedMapIdBitmap.putAll(res.committedMapIdBitMap)
+        val (masterIds, slaveIds) = shuffleCommittedInfo.synchronized {
+          (
+            masterParts.asScala
+              .filterNot(shuffleCommittedInfo.handledCommitPartitionRequests.contains)
+              .map(_.getUniqueId).asJava,
+            slaveParts.asScala
+              .filterNot(shuffleCommittedInfo.handledCommitPartitionRequests.contains)
+              .map(_.getUniqueId).asJava)
         }
 
-        totalWritten.add(res.totalWritten)
-        fileCount.add(res.fileCount)
-        currentShuffleFileCount.add(res.fileCount)
+        commitFiles(
+          applicationId,
+          shuffleId,
+          shuffleCommittedInfo,
+          worker,
+          masterIds,
+          slaveIds,
+          commitFilesFailedWorkers)
       }
     }
 
     def hasCommitFailedIds: Boolean = {
       val shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId)
-      if (!pushReplicateEnabled && failedMasterPartitionIds.size() != 0) {
-        val msg = failedMasterPartitionIds.asScala.map { case (partitionId, workerInfo) =>
-          s"Lost partition $partitionId in worker [${workerInfo.readableAddress()}]"
-        }.mkString("\n")
+      if (!pushReplicateEnabled && shuffleCommittedInfo.failedMasterPartitionIds.size() != 0) {
+        val msg =
+          shuffleCommittedInfo.failedMasterPartitionIds.asScala.map {
+            case (partitionId, workerInfo) =>
+              s"Lost partition $partitionId in worker [${workerInfo.readableAddress()}]"
+          }.mkString("\n")
         logError(
           s"""
              |For shuffle $shuffleKey partition data lost:
@@ -1080,14 +1178,16 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
              |""".stripMargin)
         true
       } else {
-        val failedBothPartitionIdsToWorker = failedMasterPartitionIds.asScala.flatMap {
-          case (partitionId, worker) =>
-            if (failedSlavePartitionIds.contains(partitionId)) {
-              Some(partitionId -> (worker, failedSlavePartitionIds.get(partitionId)))
-            } else {
-              None
-            }
-        }
+        val failedBothPartitionIdsToWorker =
+          shuffleCommittedInfo.failedMasterPartitionIds.asScala.flatMap {
+            case (partitionId, worker) =>
+              if (shuffleCommittedInfo.failedSlavePartitionIds.contains(partitionId)) {
+                Some(partitionId -> (worker, shuffleCommittedInfo.failedSlavePartitionIds.get(
+                  partitionId)))
+              } else {
+                None
+              }
+          }
         if (failedBothPartitionIdsToWorker.nonEmpty) {
           val msg = failedBothPartitionIdsToWorker.map {
             case (partitionId, (masterWorker, slaveWorker)) =>
@@ -1106,26 +1206,31 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
       }
     }
 
+    while (shuffleCommittedInfo.inFlightCommitRequest.get() > 0) {
+      Thread.sleep(1000)
+    }
+
     val dataLost = hasCommitFailedIds
 
     if (!dataLost) {
       val committedPartitions = new util.HashMap[String, PartitionLocation]
-      committedMasterIds.asScala.foreach { id =>
-        if (committedMasterStorageInfos.get(id) == null) {
+      shuffleCommittedInfo.committedMasterIds.asScala.foreach { id =>
+        if (shuffleCommittedInfo.committedMasterStorageInfos.get(id) == null) {
           logDebug(s"$applicationId-$shuffleId $id storage hint was not returned")
         } else {
-          masterPartMap.get(id).setStorageInfo(committedMasterStorageInfos.get(id))
-          masterPartMap.get(id).setMapIdBitMap(committedMapIdBitmap.get(id))
+          masterPartMap.get(id).setStorageInfo(
+            shuffleCommittedInfo.committedMasterStorageInfos.get(id))
+          masterPartMap.get(id).setMapIdBitMap(shuffleCommittedInfo.committedMapIdBitmap.get(id))
           committedPartitions.put(id, masterPartMap.get(id))
         }
       }
 
-      committedSlaveIds.asScala.foreach { id =>
+      shuffleCommittedInfo.committedSlaveIds.asScala.foreach { id =>
         val slavePartition = slavePartMap.get(id)
-        if (committedSlaveStorageInfos.get(id) == null) {
+        if (shuffleCommittedInfo.committedSlaveStorageInfos.get(id) == null) {
           logDebug(s"$applicationId-$shuffleId $id storage hint was not returned")
         } else {
-          slavePartition.setStorageInfo(committedSlaveStorageInfos.get(id))
+          slavePartition.setStorageInfo(shuffleCommittedInfo.committedSlaveStorageInfos.get(id))
           val masterPartition = committedPartitions.get(id)
           if (masterPartition ne null) {
             masterPartition.setPeer(slavePartition)
@@ -1133,7 +1238,7 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
           } else {
             logInfo(s"Shuffle $shuffleId partition $id: master lost, " +
               s"use slave $slavePartition.")
-            slavePartition.setMapIdBitMap(committedMapIdBitmap.get(id))
+            slavePartition.setMapIdBitMap(shuffleCommittedInfo.committedMapIdBitmap.get(id))
             committedPartitions.put(id, slavePartition)
           }
         }
@@ -1151,7 +1256,7 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
       }
 
       logInfo(s"Shuffle $shuffleId " +
-        s"commit files complete. File count ${currentShuffleFileCount.sum()} " +
+        s"commit files complete. File count ${shuffleCommittedInfo.currentShuffleFileCount.sum()} " +
         s"using ${(System.nanoTime() - commitFileStartTime) / 1000000} ms")
     }
 
@@ -1178,6 +1283,96 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
       ReleaseSlots(applicationId, shuffleId, List.empty.asJava, List.empty.asJava))
   }
 
+  private def commitFiles(
+      applicationId: String,
+      shuffleId: Int,
+      shuffleCommittedInfo: ShuffleCommittedInfo,
+      worker: WorkerInfo,
+      masterIds: util.List[String],
+      slaveIds: util.List[String],
+      commitFilesFailedWorkers: ConcurrentHashMap[WorkerInfo, (StatusCode, Long)]): Unit = {
+
+    val res =
+      if (!testRetryCommitFiles) {
+        val commitFiles = CommitFiles(
+          applicationId,
+          shuffleId,
+          masterIds,
+          slaveIds,
+          shuffleMapperAttempts.get(shuffleId),
+          commitEpoch.incrementAndGet())
+        shuffleCommittedInfo.inFlightCommitRequest.incrementAndGet()
+        val res = requestCommitFilesWithRetry(worker.endpoint, commitFiles)
+        shuffleCommittedInfo.inFlightCommitRequest.decrementAndGet()
+
+        res.status match {
+          case StatusCode.SUCCESS => // do nothing
+          case StatusCode.PARTIAL_SUCCESS | StatusCode.SHUFFLE_NOT_REGISTERED | StatusCode.FAILED =>
+            logDebug(s"Request $commitFiles return ${res.status} for " +
+              s"${Utils.makeShuffleKey(applicationId, shuffleId)}")
+            commitFilesFailedWorkers.put(worker, (res.status, System.currentTimeMillis()))
+          case _ => // won't happen
+        }
+        res
+      } else {
+        // for test
+        val commitFiles1 = CommitFiles(
+          applicationId,
+          shuffleId,
+          masterIds.subList(0, masterIds.size() / 2),
+          slaveIds.subList(0, slaveIds.size() / 2),
+          shuffleMapperAttempts.get(shuffleId),
+          commitEpoch.incrementAndGet())
+        val res1 = requestCommitFilesWithRetry(worker.endpoint, commitFiles1)
+
+        val commitFiles = CommitFiles(
+          applicationId,
+          shuffleId,
+          masterIds.subList(masterIds.size() / 2, masterIds.size()),
+          slaveIds.subList(slaveIds.size() / 2, slaveIds.size()),
+          shuffleMapperAttempts.get(shuffleId),
+          commitEpoch.incrementAndGet())
+        val res2 = requestCommitFilesWithRetry(worker.endpoint, commitFiles)
+
+        res1.committedMasterStorageInfos.putAll(res2.committedMasterStorageInfos)
+        res1.committedSlaveStorageInfos.putAll(res2.committedSlaveStorageInfos)
+        res1.committedMapIdBitMap.putAll(res2.committedMapIdBitMap)
+        CommitFilesResponse(
+          status = if (res1.status == StatusCode.SUCCESS) res2.status else res1.status,
+          (res1.committedMasterIds.asScala ++ res2.committedMasterIds.asScala).toList.asJava,
+          (res1.committedSlaveIds.asScala ++ res1.committedSlaveIds.asScala).toList.asJava,
+          (res1.failedMasterIds.asScala ++ res1.failedMasterIds.asScala).toList.asJava,
+          (res1.failedSlaveIds.asScala ++ res2.failedSlaveIds.asScala).toList.asJava,
+          res1.committedMasterStorageInfos,
+          res1.committedSlaveStorageInfos,
+          res1.committedMapIdBitMap,
+          res1.totalWritten + res2.totalWritten,
+          res1.fileCount + res2.fileCount)
+      }
+
+    shuffleCommittedInfo.synchronized {
+      // record committed partitionIds
+      shuffleCommittedInfo.committedMasterIds.addAll(res.committedMasterIds)
+      shuffleCommittedInfo.committedSlaveIds.addAll(res.committedSlaveIds)
+
+      // record committed partitions storage hint and disk hint
+      shuffleCommittedInfo.committedMasterStorageInfos.putAll(res.committedMasterStorageInfos)
+      shuffleCommittedInfo.committedSlaveStorageInfos.putAll(res.committedSlaveStorageInfos)
+
+      // record failed partitions
+      shuffleCommittedInfo.failedMasterPartitionIds.putAll(
+        res.failedMasterIds.asScala.map((_, worker)).toMap.asJava)
+      shuffleCommittedInfo.failedSlavePartitionIds.putAll(
+        res.failedSlaveIds.asScala.map((_, worker)).toMap.asJava)
+
+      shuffleCommittedInfo.committedMapIdBitmap.putAll(res.committedMapIdBitMap)
+
+      totalWritten.add(res.totalWritten)
+      fileCount.add(res.fileCount)
+      shuffleCommittedInfo.currentShuffleFileCount.add(res.fileCount)
+    }
+  }
+
   private def handleUnregisterShuffle(
       appId: String,
       shuffleId: Int): Unit = {
@@ -1586,6 +1781,7 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
         stageEndShuffleSet.remove(shuffleId)
         changePartitionRequests.remove(shuffleId)
         inBatchPartitions.remove(shuffleId)
+        committedPartitionInfo.remove(shuffleId)
         unregisterShuffleTime.remove(shuffleId)
         shuffleAllocatedWorkers.remove(shuffleId)
         latestPartitionLocation.remove(shuffleId)
@@ -1699,7 +1895,7 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
           return endpoint.askSync[CommitFilesResponse](message)
         }
       } catch {
-        case e: Exception =>
+        case e: Throwable =>
           retryTimes += 1
           logError(
             s"AskSync CommitFiles for ${message.shuffleId} failed (attempt $retryTimes/$maxRetries).",
diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 8881a9cb..8d07d4b6 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -670,6 +670,9 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
   def batchHandleChangePartitionEnabled: Boolean = get(BATCH_HANDLE_CHANGE_PARTITION_ENABLED)
   def batchHandleChangePartitionNumThreads: Int = get(BATCH_HANDLE_CHANGE_PARTITION_THREADS)
   def batchHandleChangePartitionRequestInterval: Long = get(BATCH_HANDLE_CHANGE_PARTITION_INTERVAL)
+  def batchHandleCommitPartitionEnabled: Boolean = get(BATCH_HANDLE_COMMIT_PARTITION_ENABLED)
+  def batchHandleCommitPartitionNumThreads: Int = get(BATCH_HANDLE_COMMIT_PARTITION_THREADS)
+  def batchHandleCommitPartitionRequestInterval: Long = get(BATCH_HANDLED_COMMIT_PARTITION_INTERVAL)
   def rpcCacheSize: Int = get(RPC_CACHE_SIZE)
   def rpcCacheConcurrencyLevel: Int = get(RPC_CACHE_CONCURRENCY_LEVEL)
   def rpcCacheExpireTime: Long = get(RPC_CACHE_EXPIRE_TIME)
@@ -2129,6 +2132,31 @@ object CelebornConf extends Logging {
       .timeConf(TimeUnit.MILLISECONDS)
       .createWithDefaultString("100ms")
 
+  val BATCH_HANDLE_COMMIT_PARTITION_ENABLED: ConfigEntry[Boolean] =
+    buildConf("celeborn.shuffle.batchHandleCommitPartition.enabled")
+      .categories("client")
+      .doc("When true, LifecycleManager will handle commit partition requests in batch. " +
+        "Otherwise, LifecycleManager won't commit partitions before stage end.")
+      .version("0.2.0")
+      .booleanConf
+      .createWithDefault(false)
+
+  val BATCH_HANDLE_COMMIT_PARTITION_THREADS: ConfigEntry[Int] =
+    buildConf("celeborn.shuffle.batchHandleCommitPartition.threads")
+      .categories("client")
+      .doc("Threads number for LifecycleManager to handle commit partition request in batch.")
+      .version("0.2.0")
+      .intConf
+      .createWithDefault(8)
+
+  val BATCH_HANDLED_COMMIT_PARTITION_INTERVAL: ConfigEntry[Long] =
+    buildConf("celeborn.shuffle.batchHandleCommitPartition.interval")
+      .categories("client")
+      .doc("Interval for LifecycleManager to schedule handling commit partition requests in batch.")
+      .version("0.2.0")
+      .timeConf(TimeUnit.MILLISECONDS)
+      .createWithDefaultString("5s")
+
   val PORT_MAX_RETRY: ConfigEntry[Int] =
     buildConf("celeborn.port.maxRetries")
       .withAlternative("rss.master.port.maxretry")
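
A minimal sketch of turning the new batching path on from application code, assuming CelebornConf keeps its SparkConf-style no-arg constructor and chainable set(key, value); only the three keys come from this commit, the non-default values are examples:

import org.apache.celeborn.common.CelebornConf

object EnableBatchCommitExample {
  val conf: CelebornConf = new CelebornConf()
    // Off by default; enables committing HARD_SPLIT partitions before stage end.
    .set("celeborn.shuffle.batchHandleCommitPartition.enabled", "true")
    // Optional tuning; the defaults are 8 threads and a 5s scheduling interval.
    .set("celeborn.shuffle.batchHandleCommitPartition.threads", "16")
    .set("celeborn.shuffle.batchHandleCommitPartition.interval", "2s")
}
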
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index 3910f1d7..d322cf9d 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -45,6 +45,9 @@ license: |
 | celeborn.shuffle.batchHandleChangePartition.enabled | false | When true, LifecycleManager will handle change partition request in batch. Otherwise, LifecycleManager will process the requests one by one | 0.2.0 | 
 | celeborn.shuffle.batchHandleChangePartition.interval | 100ms | Interval for LifecycleManager to schedule handling change partition requests in batch. | 0.2.0 | 
 | celeborn.shuffle.batchHandleChangePartition.threads | 8 | Threads number for LifecycleManager to handle change partition request in batch. | 0.2.0 | 
+| celeborn.shuffle.batchHandleCommitPartition.enabled | false | When true, LifecycleManager will handle commit partition requests in batch. Otherwise, LifecycleManager won't commit partitions before stage end. | 0.2.0 | 
+| celeborn.shuffle.batchHandleCommitPartition.interval | 5s | Interval for LifecycleManager to schedule handling commit partition requests in batch. | 0.2.0 | 
+| celeborn.shuffle.batchHandleCommitPartition.threads | 8 | Threads number for LifecycleManager to handle commit partition request in batch. | 0.2.0 | 
 | celeborn.shuffle.chuck.size | 8m | Max chunk size of reducer's merged shuffle data. For example, if a reducer's shuffle data is 128M and the data will need 16 fetch chunk requests to fetch. | 0.2.0 | 
 | celeborn.shuffle.compression.codec | LZ4 | The codec used to compress shuffle data. By default, Celeborn provides two codecs: `lz4` and `zstd`. | 0.2.0 | 
 | celeborn.shuffle.compression.zstd.level | 1 | Compression level for Zstd compression codec, its value should be an integer between -5 and 22. Increasing the compression level will result in better compression at the expense of more CPU and memory. | 0.2.0 | 
diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
index 4cbeecf6..c615f88b 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
@@ -36,7 +36,7 @@ import org.apache.celeborn.common.network.server.BaseMessageHandler
 import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionSplitMode}
 import org.apache.celeborn.common.protocol.message.StatusCode
 import org.apache.celeborn.common.unsafe.Platform
-import org.apache.celeborn.service.deploy.worker.storage.{FileWriter, LocalFlusher}
+import org.apache.celeborn.service.deploy.worker.storage.{FileWriter, LocalFlusher, StorageManager}
 
 class PushDataHandler extends BaseMessageHandler with Logging {
 
@@ -52,6 +52,7 @@ class PushDataHandler extends BaseMessageHandler with Logging {
   var diskReserveSize: Long = _
   var partitionSplitMinimumSize: Long = _
   var shutdown: AtomicBoolean = _
+  var storageManager: StorageManager = _
 
   def init(worker: Worker): Unit = {
     workerSource = worker.workerSource
@@ -65,6 +66,7 @@ class PushDataHandler extends BaseMessageHandler with Logging {
     workerInfo = worker.workerInfo
     diskReserveSize = worker.conf.diskReserveSize
     partitionSplitMinimumSize = worker.conf.partitionSplitMinimumSize
+    storageManager = worker.storageManager
     shutdown = worker.shutdown
 
     logInfo(s"diskReserveSize $diskReserveSize")
@@ -171,11 +173,21 @@ class PushDataHandler extends BaseMessageHandler with Logging {
           wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
         }
       } else {
-        val msg = s"Partition location wasn't found for task(shuffle $shuffleKey, map $mapId, " +
-          s"attempt $attemptId, uniqueId ${pushData.partitionUniqueId})."
-        logWarning(s"[handlePushData] $msg")
-        callback.onFailure(
-          new Exception(StatusCode.PUSH_DATA_FAIL_PARTITION_NOT_FOUND.getMessage()))
+        if (storageManager.shuffleKeySet().contains(shuffleKey)) {
+          // The shuffle key is missing from shuffleMapperAttempts but still present in
+          // StorageManager, so this should be a HARD_SPLIT partition that some task is
+          // still pushing data to after a worker restart.
+          logInfo(
+            s"Receive push data for committed hard split partition of (shuffle $shuffleKey, " +
+              s"map $mapId attempt $attemptId)")
+          wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
+        } else {
+          val msg = s"Partition location wasn't found for task(shuffle $shuffleKey, map $mapId, " +
+            s"attempt $attemptId, uniqueId ${pushData.partitionUniqueId})."
+          logWarning(s"[handlePushData] $msg")
+          callback.onFailure(
+            new Exception(StatusCode.PUSH_DATA_FAIL_PARTITION_NOT_FOUND.getMessage()))
+        }
       }
       return
     }
@@ -339,10 +351,20 @@ class PushDataHandler extends BaseMessageHandler with Logging {
             wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
           }
         } else {
-          val msg = s"Partition location wasn't found for task(shuffle $shuffleKey, map $mapId," +
-            s" attempt $attemptId, uniqueId $id)."
-          logWarning(s"[handlePushMergedData] $msg")
-          callback.onFailure(new Exception(msg))
+          if (storageManager.shuffleKeySet().contains(shuffleKey)) {
+            // The shuffle key is missing from shuffleMapperAttempts but still present in
+            // StorageManager, so this should be a HARD_SPLIT partition that some task is
+            // still pushing data to after a worker restart.
+            logInfo(
+              s"Receive push merged data for committed hard split partition of (shuffle $shuffleKey, " +
+                s"map $mapId attempt $attemptId)")
+            wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
+          } else {
+            val msg = s"Partition location wasn't found for task(shuffle $shuffleKey, map $mapId," +
+              s" attempt $attemptId, uniqueId $id)."
+            logWarning(s"[handlePushMergedData] $msg")
+            callback.onFailure(new Exception(msg))
+          }
         }
         return
       }
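
The PushDataHandler change reduces to one fallback decision when the worker can no longer resolve the pushed partition's location: if StorageManager still knows the shuffle key, answer HARD_SPLIT so the client revives instead of failing the push. A condensed sketch of that decision, using hypothetical types rather than the worker's real ones:

// Hypothetical types standing in for the worker's state; not the Celeborn API.
final case class PushContext(shuffleKey: String, mapId: Int, attemptId: Int, uniqueId: String)

sealed trait PushReply
case object HardSplit extends PushReply
final case class PushFailed(message: String) extends PushReply

object HardSplitFallbackSketch {
  // Invoked only when the pushed partition's location can no longer be resolved.
  def onUnknownLocation(ctx: PushContext, storageHasShuffle: String => Boolean): PushReply =
    if (storageHasShuffle(ctx.shuffleKey)) {
      // Shuffle data still exists in storage but the location is gone: treat it as a committed
      // HARD_SPLIT partition (e.g. after a worker restart) and let the client revive and retry.
      HardSplit
    } else {
      PushFailed(s"Partition location wasn't found for task(shuffle ${ctx.shuffleKey}, " +
        s"map ${ctx.mapId}, attempt ${ctx.attemptId}, uniqueId ${ctx.uniqueId}).")
    }
}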