Posted to commits@spark.apache.org by mr...@apache.org on 2022/01/05 07:47:39 UTC

[spark] branch master updated: [SPARK-33701][SHUFFLE] Adaptive shuffle merge finalization for push-based shuffle

This is an automated email from the ASF dual-hosted git repository.

mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f6128a6  [SPARK-33701][SHUFFLE] Adaptive shuffle merge finalization for push-based shuffle
f6128a6 is described below

commit f6128a6f4215dc45a19209d799dd9bf98fab6d8a
Author: Venkata krishnan Sowrirajan <vs...@linkedin.com>
AuthorDate: Wed Jan 5 01:47:01 2022 -0600

    [SPARK-33701][SHUFFLE] Adaptive shuffle merge finalization for push-based shuffle
    
    ### What changes were proposed in this pull request?
    
    As part of SPARK-32920, a simple approach to shuffle merge finalization was implemented for push-based shuffle. Shuffle merge finalization is the final operation that happens at the end of a stage, once all of its tasks have completed, asking all the external shuffle services to complete the shuffle merge for the stage. Once this request is completed, no more shuffle pushes are accepted. With this approach, `DAGScheduler` waits for a fixed time of 10s (`spark.shuffle.push.finalize.timeout`) to allow some time [...]
    
    In this PR, instead of waiting a fixed amount of time before shuffle merge finalization, finalization is now controlled adaptively: once a minimum fraction of map tasks have completed their shuffle push (`spark.shuffle.push.minCompletedPushRatio`), shuffle merge finalization is scheduled. Additionally, if the total shuffle data generated is smaller than a minimum shuffle size threshold (`spark.shuffle.push.minShuffleSizeToWait`), shuffle merge finalization is scheduled immediately.
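
    Below is a minimal sketch of the adaptive decision described above, in plain Scala. It is illustrative only: the function name and parameters are placeholders rather than the actual `DAGScheduler` API, and the config corresponding to each threshold is noted in a comment.

        // Hypothetical sketch of the adaptive finalization decision (not Spark code).
        def finalizationDelaySeconds(
            totalShuffleBytes: Long,  // shuffle bytes produced by the map stage
            pushCompleted: Int,       // map tasks that have finished pushing their blocks
            numMapTasks: Int,
            minSizeToWait: Long,      // spark.shuffle.push.minShuffleSizeToWait
            minPushRatio: Double,     // spark.shuffle.push.minCompletedPushRatio
            defaultTimeoutSec: Long   // spark.shuffle.push.finalize.timeout
          ): Long = {
          if (totalShuffleBytes < minSizeToWait) {
            0L  // tiny shuffle: finalize immediately, without waiting for merge results
          } else if (pushCompleted.toDouble / numMapTasks >= minPushRatio) {
            0L  // enough map tasks have finished pushing: finalize right away
          } else {
            defaultTimeoutSec  // otherwise fall back to the fixed finalization timeout
          }
        }
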
    ### Why are the changes needed?
    
    This is a performance improvement to the existing push-based shuffle functionality.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, this adds the user-facing configs `spark.shuffle.push.minCompletedPushRatio` and `spark.shuffle.push.minShuffleSizeToWait`.
    
    ### How was this patch tested?
    
    Added unit tests in `DAGSchedulerSuite` and `ShuffleBlockPusherSuite`.
    
    Lead-authored-by: Min Shen <mshen@linkedin.com>
    Co-authored-by: Venkata krishnan Sowrirajan <vsowrirajan@linkedin.com>
    
    Closes #33896 from venkata91/SPARK-33701.
    
    Lead-authored-by: Venkata krishnan Sowrirajan <vs...@linkedin.com>
    Co-authored-by: Min Shen <ms...@linkedin.com>
    Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
 .../main/scala/org/apache/spark/Dependency.scala   |  35 ++-
 .../scala/org/apache/spark/MapOutputTracker.scala  |   6 +-
 .../src/main/scala/org/apache/spark/SparkEnv.scala |   3 +
 .../executor/CoarseGrainedExecutorBackend.scala    |   6 +
 .../org/apache/spark/internal/config/package.scala |  27 ++
 .../org/apache/spark/scheduler/DAGScheduler.scala  | 278 +++++++++++++----
 .../apache/spark/scheduler/DAGSchedulerEvent.scala |   4 +
 .../cluster/CoarseGrainedClusterMessage.scala      |   3 +
 .../cluster/CoarseGrainedSchedulerBackend.scala    |   3 +
 .../apache/spark/shuffle/ShuffleBlockPusher.scala  |  39 ++-
 .../apache/spark/scheduler/DAGSchedulerSuite.scala | 340 +++++++++++++++++++--
 .../spark/shuffle/ShuffleBlockPusherSuite.scala    | 101 +++++-
 docs/configuration.md                              |  16 +
 13 files changed, 772 insertions(+), 89 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index 1b4e7ba..8e348ee 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -17,8 +17,12 @@
 
 package org.apache.spark
 
+import java.util.concurrent.ScheduledFuture
+
 import scala.reflect.ClassTag
 
+import org.roaringbitmap.RoaringBitmap
+
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
@@ -131,9 +135,7 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
   def shuffleMergeId: Int = _shuffleMergeId
 
   def setMergerLocs(mergerLocs: Seq[BlockManagerId]): Unit = {
-    if (mergerLocs != null) {
-      this.mergerLocs = mergerLocs
-    }
+    this.mergerLocs = mergerLocs
   }
 
   def getMergerLocs: Seq[BlockManagerId] = mergerLocs
@@ -160,6 +162,8 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
     _shuffleMergedFinalized = false
     mergerLocs = Nil
     _shuffleMergeId += 1
+    finalizeTask = None
+    shufflePushCompleted.clear()
   }
 
   private def canShuffleMergeBeEnabled(): Boolean = {
@@ -169,11 +173,34 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
     if (isPushShuffleEnabled && rdd.isBarrier()) {
       logWarning("Push-based shuffle is currently not supported for barrier stages")
     }
-    isPushShuffleEnabled &&
+    isPushShuffleEnabled && numPartitions > 0 &&
       // TODO: SPARK-35547: Push based shuffle is currently unsupported for Barrier stages
       !rdd.isBarrier()
   }
 
+  @transient private[this] val shufflePushCompleted = new RoaringBitmap()
+
+  /**
+   * Mark a given map task as push completed in the tracking bitmap.
+   * Using the bitmap ensures that the same map task launched multiple times due to
+   * either speculation or stage retry is only counted once.
+   * @param mapIndex Map task index
+   * @return number of map tasks with block push completed
+   */
+  def incPushCompleted(mapIndex: Int): Int = {
+    shufflePushCompleted.add(mapIndex)
+    shufflePushCompleted.getCardinality
+  }
+
+  // Only used by DAGScheduler to coordinate shuffle merge finalization
+  @transient private[this] var finalizeTask: Option[ScheduledFuture[_]] = None
+
+  def getFinalizeTask: Option[ScheduledFuture[_]] = finalizeTask
+
+  def setFinalizeTask(task: ScheduledFuture[_]): Unit = {
+    finalizeTask = Option(task)
+  }
+
   _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
   _rdd.sparkContext.shuffleDriverComponents.registerShuffle(shuffleId)
 }
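
The `RoaringBitmap`-based counter added above only counts each map index once, even if the same map task runs multiple times due to speculation or a stage retry. A standalone sketch of the same idea, assuming only the org.roaringbitmap dependency Spark already uses (the class name here is hypothetical):

    import org.roaringbitmap.RoaringBitmap

    // Counts distinct map indices whose block push has completed; duplicate
    // notifications for the same index are no-ops.
    class PushCompletionCounter {
      private val completed = new RoaringBitmap()

      def incPushCompleted(mapIndex: Int): Int = {
        completed.add(mapIndex)
        completed.getCardinality
      }
    }

    // e.g. reporting map indices 0, 1 and then 1 again still yields a cardinality of 2
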
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index af26abc..d71fb09 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -917,7 +917,7 @@ private[spark] class MapOutputTrackerMaster(
         Runtime.getRuntime.availableProcessors(),
         statuses.length.toLong * totalSizes.length / parallelAggThreshold + 1).toInt
       if (parallelism <= 1) {
-        for (s <- statuses) {
+        statuses.filter(_ != null).foreach { s =>
           for (i <- 0 until totalSizes.length) {
             totalSizes(i) += s.getSizeForBlock(i)
           }
@@ -928,8 +928,8 @@ private[spark] class MapOutputTrackerMaster(
           implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
           val mapStatusSubmitTasks = equallyDivide(totalSizes.length, parallelism).map {
             reduceIds => Future {
-              for (s <- statuses; i <- reduceIds) {
-                totalSizes(i) += s.getSizeForBlock(i)
+              statuses.filter(_ != null).foreach { s =>
+                reduceIds.foreach(i => totalSizes(i) += s.getSizeForBlock(i))
               }
             }
           }
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 0388c7b..d07614a 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -32,6 +32,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.api.python.PythonWorkerFactory
 import org.apache.spark.broadcast.BroadcastManager
+import org.apache.spark.executor.ExecutorBackend
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.internal.config._
 import org.apache.spark.memory.{MemoryManager, UnifiedMemoryManager}
@@ -81,6 +82,8 @@ class SparkEnv (
 
   private[spark] var driverTmpDir: Option[String] = None
 
+  private[spark] var executorBackend: Option[ExecutorBackend] = None
+
   private[spark] def stop(): Unit = {
 
     if (!isStopped) {
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 43887a7..fb7b4e6 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -106,6 +106,7 @@ private[spark] class CoarseGrainedExecutorBackend(
     rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref =>
       // This is a very fast action so we can use "ThreadUtils.sameThread"
       driver = Some(ref)
+      env.executorBackend = Option(this)
       ref.ask[Boolean](RegisterExecutor(executorId, self, hostname, cores, extractLogUrls,
         extractAttributes, _resources, resourceProfile.id))
     }(ThreadUtils.sameThread).onComplete {
@@ -162,6 +163,11 @@ private[spark] class CoarseGrainedExecutorBackend(
       .map(e => (e._1.substring(prefix.length).toUpperCase(Locale.ROOT), e._2)).toMap
   }
 
+  def notifyDriverAboutPushCompletion(shuffleId: Int, shuffleMergeId: Int, mapIndex: Int): Unit = {
+    val msg = ShufflePushCompletion(shuffleId, shuffleMergeId, mapIndex)
+    driver.foreach(_.send(msg))
+  }
+
   override def receive: PartialFunction[Any, Unit] = {
     case RegisteredExecutor =>
       logInfo("Successfully registered with driver")
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 71a11f6..a942ba5 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -2193,6 +2193,33 @@ package object config {
       // with small MB sized chunk of data.
       .createWithDefaultString("3m")
 
+  private[spark] val PUSH_BASED_SHUFFLE_MERGE_FINALIZE_THREADS =
+    ConfigBuilder("spark.shuffle.push.merge.finalizeThreads")
+      .doc("Number of threads used by driver to finalize shuffle merge. Since it could" +
+        " potentially take seconds for a large shuffle to finalize, having multiple threads helps" +
+        " driver to handle concurrent shuffle merge finalize requests when push-based" +
+        " shuffle is enabled.")
+      .version("3.3.0")
+      .intConf
+      .createWithDefault(3)
+
+  private[spark] val PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT =
+    ConfigBuilder("spark.shuffle.push.minShuffleSizeToWait")
+      .doc("Driver will wait for merge finalization to complete only if total shuffle size is" +
+        " more than this threshold. If total shuffle size is less, driver will immediately" +
+        " finalize the shuffle output")
+      .version("3.3.0")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefaultString("500m")
+
+  private[spark] val PUSH_BASED_SHUFFLE_MIN_PUSH_RATIO =
+    ConfigBuilder("spark.shuffle.push.minCompletedPushRatio")
+      .doc("Fraction of map partitions that should be push complete before driver starts" +
+        " shuffle merge finalization during push based shuffle")
+      .version("3.3.0")
+      .doubleConf
+      .createWithDefault(1.0)
+
   private[spark] val JAR_IVY_REPO_PATH =
     ConfigBuilder("spark.jars.ivy")
       .doc("Path to specify the Ivy user directory, used for the local Ivy cache and " +
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 4ed734c..eed71038 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler
 
 import java.io.NotSerializableException
 import java.util.Properties
-import java.util.concurrent.{ConcurrentHashMap, TimeoutException, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, ScheduledFuture, TimeoutException, TimeUnit}
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.annotation.tailrec
@@ -265,6 +265,14 @@ private[spark] class DAGScheduler(
   private val shuffleMergeFinalizeWaitSec =
     sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_TIMEOUT)
 
+  private val shuffleMergeWaitMinSizeThreshold =
+    sc.getConf.get(config.PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT)
+
+  private val shufflePushMinRatio = sc.getConf.get(config.PUSH_BASED_SHUFFLE_MIN_PUSH_RATIO)
+
+  private val shuffleMergeFinalizeNumThreads =
+    sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_THREADS)
+
   // Since SparkEnv gets initialized after DAGScheduler, externalShuffleClient needs to be
   // initialized lazily
   private lazy val externalShuffleClient: Option[BlockStoreClient] =
@@ -274,8 +282,12 @@ private[spark] class DAGScheduler(
       None
     }
 
+  // Use multi-threaded scheduled executor. The merge finalization task could take some time,
+  // depending on the time to establish connections to mergers, and the time to get MergeStatuses
+  // from all the mergers.
   private val shuffleMergeFinalizeScheduler =
-    ThreadUtils.newDaemonThreadPoolScheduledExecutor("shuffle-merge-finalizer", 8)
+    ThreadUtils.newDaemonThreadPoolScheduledExecutor("shuffle-merge-finalizer",
+      shuffleMergeFinalizeNumThreads)
 
   /**
    * Called by the TaskSetManager to report task's starting.
@@ -1065,6 +1077,14 @@ private[spark] class DAGScheduler(
   }
 
   /**
+   * Receives a notification that the shuffle push from one map task for a given
+   * shuffle has completed.
+   */
+  def shufflePushCompleted(shuffleId: Int, shuffleMergeId: Int, mapIndex: Int): Unit = {
+    eventProcessLoop.post(ShufflePushCompleted(shuffleId, shuffleMergeId, mapIndex))
+  }
+
+  /**
    * Kill a given task. It will be retried.
    *
    * @return Whether the task was successfully killed.
@@ -1407,7 +1427,7 @@ private[spark] class DAGScheduler(
             // merger locations but the corresponding shuffle map stage did not complete
             // successfully, we would still enable push for its retry.
             s.shuffleDep.setShuffleMergeEnabled(false)
-            logInfo("Push-based shuffle disabled for $stage (${stage.name}) since it" +
+            logInfo(s"Push-based shuffle disabled for $stage (${stage.name}) since it" +
               " is already shuffle merge finalized")
           }
         }
@@ -1636,6 +1656,42 @@ private[spark] class DAGScheduler(
     }
   }
 
+  private[scheduler] def checkAndScheduleShuffleMergeFinalize(
+      shuffleStage: ShuffleMapStage): Unit = {
+    // Check if a finalize task has already been scheduled. This prevents scheduling
+    // multiple shuffle merge finalizations, which can otherwise happen due to a stage retry
+    // or because shufflePushMinRatio has already been hit, etc.
+    if (shuffleStage.shuffleDep.getFinalizeTask.isEmpty) {
+      // 1. Stage indeterminate and some map outputs are not available - finalize
+      // immediately without registering shuffle merge results.
+      // 2. Stage determinate and some map outputs are not available - decide to
+      // register merge results based on map outputs size available and
+      // shuffleMergeWaitMinSizeThreshold.
+      // 3. All shuffle outputs available - decide to register merge results based
+      // on map outputs size available and shuffleMergeWaitMinSizeThreshold.
+      val totalSize = {
+        lazy val computedTotalSize =
+          mapOutputTracker.getStatistics(shuffleStage.shuffleDep).
+            bytesByPartitionId.filter(_ > 0).sum
+        if (shuffleStage.isAvailable) {
+          computedTotalSize
+        } else {
+          if (shuffleStage.isIndeterminate) {
+            0L
+          } else {
+            computedTotalSize
+          }
+        }
+      }
+
+      if (totalSize < shuffleMergeWaitMinSizeThreshold) {
+        scheduleShuffleMergeFinalize(shuffleStage, delay = 0, registerMergeResults = false)
+      } else {
+        scheduleShuffleMergeFinalize(shuffleStage, shuffleMergeFinalizeWaitSec)
+      }
+    }
+  }
+
   /**
    * Responds to a task finishing. This is called inside the event loop so it assumes that it can
    * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
@@ -1767,7 +1823,7 @@ private[spark] class DAGScheduler(
             if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) {
               if (!shuffleStage.shuffleDep.shuffleMergeFinalized &&
                 shuffleStage.shuffleDep.getMergerLocs.nonEmpty) {
-                scheduleShuffleMergeFinalize(shuffleStage)
+                checkAndScheduleShuffleMergeFinalize(shuffleStage)
               } else {
                 processShuffleMapStageCompletion(shuffleStage)
               }
@@ -2074,20 +2130,63 @@ private[spark] class DAGScheduler(
   }
 
   /**
-   * Schedules shuffle merge finalize.
+   *
+   * Schedules shuffle merge finalization.
+   *
+   * @param stage the stage to finalize shuffle merge
+   * @param delay how long to wait before finalizing shuffle merge
+   * @param registerMergeResults indicate whether DAGScheduler would register the received
+   *                             MergeStatus with MapOutputTracker and wait to schedule the reduce
+   *                             stage until MergeStatus have been received from all mergers or
+   *                             reaches timeout. For very small shuffle, this could be set to
+   *                             false to avoid impact to job runtime.
    */
-  private[scheduler] def scheduleShuffleMergeFinalize(stage: ShuffleMapStage): Unit = {
-    // TODO: SPARK-33701: Instead of waiting for a constant amount of time for finalization
-    // TODO: for all the stages, adaptively tune timeout for merge finalization
-    logInfo(("%s (%s) scheduled for finalizing" +
-      " shuffle merge in %s s").format(stage, stage.name, shuffleMergeFinalizeWaitSec))
-    shuffleMergeFinalizeScheduler.schedule(
-      new Runnable {
-        override def run(): Unit = finalizeShuffleMerge(stage)
-      },
-      shuffleMergeFinalizeWaitSec,
-      TimeUnit.SECONDS
-    )
+  private[scheduler] def scheduleShuffleMergeFinalize(
+      stage: ShuffleMapStage,
+      delay: Long,
+      registerMergeResults: Boolean = true): Unit = {
+    val shuffleDep = stage.shuffleDep
+    val scheduledTask: Option[ScheduledFuture[_]] = shuffleDep.getFinalizeTask
+    scheduledTask match {
+      case Some(task) =>
+        // If we find an already scheduled task, check if the task has been triggered yet.
+        // If it's already triggered, do nothing. Otherwise, cancel it and schedule a new
+        // one for immediate execution. Note that we should get here only when
+        // handleShufflePushCompleted schedules a finalize task after the shuffle map stage
+        // completed earlier and scheduled a task with default delay.
+        // The current task should be coming from handleShufflePushCompleted, thus the
+        // delay should be 0 and registerMergeResults should be true.
+        assert(delay == 0 && registerMergeResults)
+        if (task.getDelay(TimeUnit.NANOSECONDS) > 0 && task.cancel(false)) {
+          logInfo(s"$stage (${stage.name}) scheduled for finalizing shuffle merge immediately " +
+            s"after cancelling previously scheduled task.")
+          shuffleDep.setFinalizeTask(
+            shuffleMergeFinalizeScheduler.schedule(
+              new Runnable {
+                override def run(): Unit = finalizeShuffleMerge(stage, registerMergeResults)
+              },
+              0,
+              TimeUnit.SECONDS
+            )
+          )
+        } else {
+          logInfo(s"$stage (${stage.name}) existing scheduled task for finalizing shuffle merge" +
+            s"would either be in-progress or finished. No need to schedule shuffle merge" +
+            s" finalization again.")
+        }
+      case None =>
+        // If no previous finalization task is scheduled, schedule the finalization task.
+        logInfo(s"$stage (${stage.name}) scheduled for finalizing shuffle merge in $delay s")
+        shuffleDep.setFinalizeTask(
+          shuffleMergeFinalizeScheduler.schedule(
+            new Runnable {
+              override def run(): Unit = finalizeShuffleMerge(stage, registerMergeResults)
+            },
+            delay,
+            TimeUnit.SECONDS
+          )
+        )
+    }
   }
 
   /**
@@ -2095,38 +2194,72 @@ private[spark] class DAGScheduler(
    * the given shuffle map stage to finalize the shuffle merge process for this shuffle. This is
    * invoked in a separate thread to reduce the impact on the DAGScheduler main thread, as the
    * scheduler might need to talk to 1000s of shuffle services to finalize shuffle merge.
+   *
+   * @param stage ShuffleMapStage to finalize shuffle merge for
+   * @param registerMergeResults indicate whether DAGScheduler would register the received
+   *                             MergeStatus with MapOutputTracker and wait to schedule the reduce
+   *                             stage until MergeStatus have been received from all mergers or
+   *                             reaches timeout. For very small shuffle, this could be set to
+   *                             false to avoid impact to job runtime.
    */
-  private[scheduler] def finalizeShuffleMerge(stage: ShuffleMapStage): Unit = {
-    logInfo("%s (%s) finalizing the shuffle merge".format(stage, stage.name))
+  private[scheduler] def finalizeShuffleMerge(
+      stage: ShuffleMapStage,
+      registerMergeResults: Boolean = true): Unit = {
+    logInfo(s"$stage (${stage.name}) finalizing the shuffle merge with registering merge " +
+      s"results set to $registerMergeResults")
+    val shuffleId = stage.shuffleDep.shuffleId
+    val shuffleMergeId = stage.shuffleDep.shuffleMergeId
+    val numMergers = stage.shuffleDep.getMergerLocs.length
+    val results = (0 until numMergers).map(_ => SettableFuture.create[Boolean]())
     externalShuffleClient.foreach { shuffleClient =>
-      val shuffleId = stage.shuffleDep.shuffleId
-      val numMergers = stage.shuffleDep.getMergerLocs.length
-      val results = (0 until numMergers).map(_ => SettableFuture.create[Boolean]())
-
-      stage.shuffleDep.getMergerLocs.zipWithIndex.foreach {
-        case (shuffleServiceLoc, index) =>
-          // Sends async request to shuffle service to finalize shuffle merge on that host
-          // TODO: SPARK-35536: Cancel finalizeShuffleMerge if the stage is cancelled
-          // TODO: during shuffleMergeFinalizeWaitSec
-          shuffleClient.finalizeShuffleMerge(shuffleServiceLoc.host,
-            shuffleServiceLoc.port, shuffleId, stage.shuffleDep.shuffleMergeId,
-            new MergeFinalizerListener {
-              override def onShuffleMergeSuccess(statuses: MergeStatuses): Unit = {
-                assert(shuffleId == statuses.shuffleId)
-                eventProcessLoop.post(RegisterMergeStatuses(stage, MergeStatus.
-                  convertMergeStatusesToMergeStatusArr(statuses, shuffleServiceLoc)))
-                results(index).set(true)
-              }
+      if (!registerMergeResults) {
+        results.foreach(_.set(true))
+        // Finalize in separate thread as shuffle merge is a no-op in this case
+        shuffleMergeFinalizeScheduler.schedule(new Runnable {
+          override def run(): Unit = {
+            stage.shuffleDep.getMergerLocs.foreach {
+              case shuffleServiceLoc =>
+                // Sends async request to shuffle service to finalize shuffle merge on that host.
+                // Since merge statuses will not be registered in this case,
+                // we pass a no-op listener.
+                shuffleClient.finalizeShuffleMerge(shuffleServiceLoc.host,
+                  shuffleServiceLoc.port, shuffleId, shuffleMergeId,
+                  new MergeFinalizerListener {
+                    override def onShuffleMergeSuccess(statuses: MergeStatuses): Unit = {
+                    }
 
-              override def onShuffleMergeFailure(e: Throwable): Unit = {
-                logWarning(s"Exception encountered when trying to finalize shuffle " +
-                  s"merge on ${shuffleServiceLoc.host} for shuffle $shuffleId", e)
-                // Do not fail the future as this would cause dag scheduler to prematurely
-                // give up on waiting for merge results from the remaining shuffle services
-                // if one fails
-                results(index).set(false)
-              }
-            })
+                    override def onShuffleMergeFailure(e: Throwable): Unit = {
+                    }
+                  })
+            }
+          }
+        }, 0, TimeUnit.SECONDS)
+      } else {
+        stage.shuffleDep.getMergerLocs.zipWithIndex.foreach {
+          case (shuffleServiceLoc, index) =>
+            // Sends async request to shuffle service to finalize shuffle merge on that host
+            // TODO: SPARK-35536: Cancel finalizeShuffleMerge if the stage is cancelled
+            // TODO: during shuffleMergeFinalizeWaitSec
+            shuffleClient.finalizeShuffleMerge(shuffleServiceLoc.host,
+              shuffleServiceLoc.port, shuffleId, shuffleMergeId,
+              new MergeFinalizerListener {
+                override def onShuffleMergeSuccess(statuses: MergeStatuses): Unit = {
+                  assert(shuffleId == statuses.shuffleId)
+                  eventProcessLoop.post(RegisterMergeStatuses(stage, MergeStatus.
+                    convertMergeStatusesToMergeStatusArr(statuses, shuffleServiceLoc)))
+                  results(index).set(true)
+                }
+
+                override def onShuffleMergeFailure(e: Throwable): Unit = {
+                  logWarning(s"Exception encountered when trying to finalize shuffle " +
+                    s"merge on ${shuffleServiceLoc.host} for shuffle $shuffleId", e)
+                  // Do not fail the future as this would cause dag scheduler to prematurely
+                  // give up on waiting for merge results from the remaining shuffle services
+                  // if one fails
+                  results(index).set(false)
+                }
+              })
+        }
       }
       // DAGScheduler only waits for a limited amount of time for the merge results.
       // It will attempt to submit the next stage(s) irrespective of whether merge results
@@ -2185,15 +2318,45 @@ private[spark] class DAGScheduler(
     }
   }
 
-  private[scheduler] def handleShuffleMergeFinalized(stage: ShuffleMapStage): Unit = {
-    // Only update MapOutputTracker metadata if the stage is still active. i.e not cancelled.
-    if (runningStages.contains(stage)) {
-      stage.shuffleDep.markShuffleMergeFinalized()
-      processShuffleMapStageCompletion(stage)
-    } else {
-      // Unregister all merge results if the stage is currently not
-      // active (i.e. the stage is cancelled)
-      mapOutputTracker.unregisterAllMergeResult(stage.shuffleDep.shuffleId)
+  private[scheduler] def handleShuffleMergeFinalized(stage: ShuffleMapStage,
+        shuffleMergeId: Int): Unit = {
+    // Check if update is for the same merge id - finalization might have completed for an earlier
+    // adaptive attempt while the stage might have failed or been killed and the shuffle is
+    // being re-executed now.
+    if (stage.shuffleDep.shuffleMergeId == shuffleMergeId) {
+      if (stage.pendingPartitions.isEmpty) {
+        if (runningStages.contains(stage)) {
+          stage.shuffleDep.markShuffleMergeFinalized()
+          processShuffleMapStageCompletion(stage)
+        } else {
+          // Unregister all merge results if the stage is currently not
+          // active (i.e. the stage is cancelled)
+          mapOutputTracker.unregisterAllMergeResult(stage.shuffleDep.shuffleId)
+        }
+      } else {
+        // stage still running, mark merge finalized. Stage completion will invoke
+        // processShuffleMapStageCompletion
+        stage.shuffleDep.markShuffleMergeFinalized()
+      }
+    }
+  }
+
+  private[scheduler] def handleShufflePushCompleted(
+      shuffleId: Int, shuffleMergeId: Int, mapIndex: Int): Unit = {
+    shuffleIdToMapStage.get(shuffleId) match {
+      case Some(mapStage) =>
+        val shuffleDep = mapStage.shuffleDep
+        // Only update shufflePushCompleted events for the current active stage map tasks.
+        // This is required to prevent shuffle merge finalization by dangling tasks of a
+        // previous attempt in the case of indeterminate stage.
+        if (shuffleDep.shuffleMergeId == shuffleMergeId) {
+          if (!shuffleDep.shuffleMergeFinalized &&
+            shuffleDep.incPushCompleted(mapIndex).toDouble / shuffleDep.rdd.partitions.length
+              >= shufflePushMinRatio) {
+            scheduleShuffleMergeFinalize(mapStage, delay = 0)
+          }
+        }
+      case None =>
     }
   }
 
@@ -2649,7 +2812,10 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
       dagScheduler.handleRegisterMergeStatuses(stage, mergeStatuses)
 
     case ShuffleMergeFinalized(stage) =>
-      dagScheduler.handleShuffleMergeFinalized(stage)
+      dagScheduler.handleShuffleMergeFinalized(stage, stage.shuffleDep.shuffleMergeId)
+
+    case ShufflePushCompleted(shuffleId, shuffleMergeId, mapIndex) =>
+      dagScheduler.handleShufflePushCompleted(shuffleId, shuffleMergeId, mapIndex)
   }
 
   override def onError(e: Throwable): Unit = {
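
The cancel-and-reschedule handling in `scheduleShuffleMergeFinalize` above follows the standard `ScheduledFuture` pattern: if a pending task has not fired yet, cancel it and rerun immediately; otherwise leave it alone. A minimal standalone sketch, using plain `java.util.concurrent` rather than Spark's `ThreadUtils` helper (all names are placeholders):

    import java.util.concurrent.{Executors, ScheduledFuture, TimeUnit}

    object FinalizeScheduling {
      private val scheduler = Executors.newScheduledThreadPool(1)
      @volatile private var finalizeTask: Option[ScheduledFuture[_]] = None

      def scheduleFinalize(delaySec: Long)(body: => Unit): Unit = finalizeTask match {
        // Pending and not fired yet: cancel it and reschedule for immediate execution.
        case Some(task) if task.getDelay(TimeUnit.NANOSECONDS) > 0 && task.cancel(false) =>
          finalizeTask = Some(schedule(0L)(body))
        // Already running or finished: nothing to do.
        case Some(_) => ()
        // Nothing scheduled yet: schedule with the requested delay.
        case None =>
          finalizeTask = Some(schedule(delaySec)(body))
      }

      private def schedule(delaySec: Long)(body: => Unit): ScheduledFuture[_] =
        scheduler.schedule(
          new Runnable { override def run(): Unit = body }, delaySec, TimeUnit.SECONDS)
    }
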
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index 307844c..f3798da 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -111,3 +111,7 @@ private[scheduler] case class RegisterMergeStatuses(
 
 private[scheduler] case class ShuffleMergeFinalized(stage: ShuffleMapStage)
   extends DAGSchedulerEvent
+
+private[scheduler] case class ShufflePushCompleted(
+    shuffleId: Int, shuffleMergeId: Int, mapIndex: Int)
+  extends DAGSchedulerEvent
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index 66ac40f..61ee865 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -85,6 +85,9 @@ private[spark] object CoarseGrainedClusterMessages {
     }
   }
 
+  case class ShufflePushCompletion(shuffleId: Int, shuffleMergeId: Int, mapIndex: Int)
+    extends CoarseGrainedClusterMessage
+
   // Internal messages in driver
   case object ReviveOffers extends CoarseGrainedClusterMessage
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 326ea83..13a7183 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -168,6 +168,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
           }
         }
 
+      case ShufflePushCompletion(shuffleId, shuffleMergeId, mapIndex) =>
+        scheduler.dagScheduler.shufflePushCompleted(shuffleId, shuffleMergeId, mapIndex)
+
       case ReviveOffers =>
         makeOffers()
 
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
index 8790371..d6972cd 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
@@ -26,6 +26,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
 
 import org.apache.spark.{ShuffleDependency, SparkConf, SparkContext, SparkEnv}
 import org.apache.spark.annotation.Since
+import org.apache.spark.executor.{CoarseGrainedExecutorBackend, ExecutorBackend}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
 import org.apache.spark.launcher.SparkLauncher
@@ -53,7 +54,7 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
   private[this] val maxBytesInFlight = conf.get(REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024
   private[this] val maxReqsInFlight = conf.get(REDUCER_MAX_REQS_IN_FLIGHT)
   private[this] val maxBlocksInFlightPerAddress = conf.get(REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS)
-  private[this] var bytesInFlight = 0L
+  private[shuffle] var bytesInFlight = 0L
   private[this] var reqsInFlight = 0
   private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]()
   private[this] val deferredPushRequests = new HashMap[BlockManagerId, Queue[PushRequest]]()
@@ -61,6 +62,10 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
   private[this] val errorHandler = createErrorHandler()
   // VisibleForTesting
   private[shuffle] val unreachableBlockMgrs = new HashSet[BlockManagerId]()
+  private[this] var shuffleId = -1
+  private[this] var mapIndex = -1
+  private[this] var shuffleMergeId = -1
+  private[this] var pushCompletionNotified = false
 
   // VisibleForTesting
   private[shuffle] def createErrorHandler(): BlockPushErrorHandler = {
@@ -84,6 +89,8 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
       }
     }
   }
+  // VisibleForTesting
+  private[shuffle] def isPushCompletionNotified = pushCompletionNotified
 
   /**
    * Initiates the block push.
@@ -101,11 +108,17 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
       mapIndex: Int): Unit = {
     val numPartitions = dep.partitioner.numPartitions
     val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle")
+    this.shuffleId = dep.shuffleId
+    this.shuffleMergeId = dep.shuffleMergeId
+    this.mapIndex = mapIndex
     val requests = prepareBlockPushRequests(numPartitions, mapIndex, dep.shuffleId,
       dep.shuffleMergeId, dataFile, partitionLengths, dep.getMergerLocs, transportConf)
     // Randomize the orders of the PushRequest, so different mappers pushing blocks at the same
     // time won't be pushing the same ranges of shuffle partitions.
     pushRequests ++= Utils.randomize(requests)
+    if (pushRequests.isEmpty) {
+      notifyDriverAboutPushCompletion()
+    }
 
     submitTask(() => {
       tryPushUpToMax()
@@ -327,11 +340,35 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
         s"stop.")
       return false
     } else {
+      if (reqsInFlight <= 0 && pushRequests.isEmpty && deferredPushRequests.isEmpty) {
+        notifyDriverAboutPushCompletion()
+      }
       remainingBlocks.isEmpty && (pushRequests.nonEmpty || deferredPushRequests.nonEmpty)
     }
   }
 
   /**
+   * Notify the driver about all the blocks generated by the current map task having been pushed.
+   * This enables the DAGScheduler to finalize shuffle merge as soon as sufficient map tasks have
+   * completed push instead of always waiting for a fixed amount of time.
+   *
+   * VisibleForTesting
+   */
+  protected def notifyDriverAboutPushCompletion(): Unit = {
+    assert(shuffleId >= 0 && mapIndex >= 0)
+    if (!pushCompletionNotified) {
+      SparkEnv.get.executorBackend match {
+        case Some(cb: CoarseGrainedExecutorBackend) =>
+          cb.notifyDriverAboutPushCompletion(shuffleId, shuffleMergeId, mapIndex)
+        case Some(eb: ExecutorBackend) =>
+          logWarning(s"Currently $eb doesn't support push-based shuffle")
+        case None =>
+      }
+      pushCompletionNotified = true
+    }
+  }
+
+  /**
    * Convert the shuffle data file of the current mapper into a list of PushRequest. Basically,
    * continuous blocks in the shuffle file are grouped into a single request to allow more
    * efficient read of the block data. Each mapper for a given shuffle will receive the same
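
The driver notification added above reduces to a "fire exactly once when everything is drained" guard: the driver is told about push completion the first time there are no pending, deferred, or in-flight push requests. A rough standalone sketch of such a guard (a hypothetical class, not the pusher's actual API):

    // Invokes notifyDriver() at most once, when all push work for the map task is drained.
    class PushCompletionGuard(notifyDriver: () => Unit) {
      private var notified = false

      def maybeNotify(reqsInFlight: Int, pendingRequests: Int, deferredRequests: Int): Unit = {
        if (!notified && reqsInFlight <= 0 && pendingRequests == 0 && deferredRequests == 0) {
          notifyDriver()
          notified = true
        }
      }
    }
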
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index afea912..76612cb 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.scheduler
 
 import java.util.Properties
-import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.{CountDownLatch, Delayed, ScheduledFuture, TimeUnit}
 import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong, AtomicReference}
 
 import scala.annotation.meta.param
@@ -124,6 +124,31 @@ class MyRDD(
   override def toString: String = "DAGSchedulerSuiteRDD " + id
 }
 
+class DummyScheduledFuture(
+    val delay: Long,
+    val registerMergeResults: Boolean)
+  extends ScheduledFuture[Int] {
+
+  override def get(timeout: Long, unit: TimeUnit): Int =
+    throw new IllegalStateException("should not be reached")
+
+  override def getDelay(unit: TimeUnit): Long = delay
+
+  override def compareTo(o: Delayed): Int =
+    throw new IllegalStateException("should not be reached")
+
+  override def cancel(mayInterruptIfRunning: Boolean): Boolean = true
+
+  override def isCancelled: Boolean =
+    throw new IllegalStateException("should not be reached")
+
+  override def isDone: Boolean =
+    throw new IllegalStateException("should not be reached")
+
+  override def get(): Int =
+    throw new IllegalStateException("should not be reached")
+}
+
 class DAGSchedulerSuiteDummyException extends Exception
 
 class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with TimeLimits {
@@ -312,16 +337,27 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
      * Schedules shuffle merge finalize.
      */
     override private[scheduler] def scheduleShuffleMergeFinalize(
-        shuffleMapStage: ShuffleMapStage): Unit = {
-      if (shuffleMergeRegister) {
+        shuffleMapStage: ShuffleMapStage,
+        delay: Long,
+        registerMergeResults: Boolean = true): Unit = {
+      if (shuffleMergeRegister && registerMergeResults) {
         for (part <- 0 until shuffleMapStage.shuffleDep.partitioner.numPartitions) {
           val mergeStatuses = Seq((part, makeMergeStatus("",
             shuffleMapStage.shuffleDep.shuffleMergeId)))
           handleRegisterMergeStatuses(shuffleMapStage, mergeStatuses)
         }
-        if (shuffleMergeFinalize) {
-          handleShuffleMergeFinalized(shuffleMapStage)
-        }
+      }
+
+      shuffleMapStage.shuffleDep.getFinalizeTask match {
+        case Some(_) =>
+          assert(delay == 0 && registerMergeResults)
+        case None =>
+      }
+
+      shuffleMapStage.shuffleDep.setFinalizeTask(
+          new DummyScheduledFuture(delay, registerMergeResults))
+      if (shuffleMergeFinalize) {
+        handleShuffleMergeFinalized(shuffleMapStage, shuffleMapStage.shuffleDep.shuffleMergeId)
       }
     }
   }
@@ -472,6 +508,12 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     assert(this.results === expected)
   }
 
+  /** Sends ShufflePushCompleted to the DAG scheduler. */
+  private def pushComplete(
+      shuffleId: Int, shuffleMergeId: Int, mapIndex: Int): Unit = {
+    runEvent(ShufflePushCompleted(shuffleId, shuffleMergeId, mapIndex))
+  }
+
   test("[SPARK-3353] parent stage should have lower stage id") {
     sc.parallelize(1 to 10).map(x => (x, x)).reduceByKey(_ + _, 4).count()
     val stageByOrderOfExecution = sparkListener.stageByOrderOfExecution
@@ -3428,6 +3470,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
   private def initPushBasedShuffleConfs(conf: SparkConf) = {
     conf.set(config.SHUFFLE_SERVICE_ENABLED, true)
     conf.set(config.PUSH_BASED_SHUFFLE_ENABLED, true)
+    conf.set(config.PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT, 1L)
     conf.set("spark.master", "pushbasedshuffleclustermanager")
     // Needed to run push-based shuffle tests in ad-hoc manner through IDE
     conf.set(Tests.IS_TESTING, true)
@@ -3439,7 +3482,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
 
   test("SPARK-32920: shuffle merge finalization") {
     initPushBasedShuffleConfs(conf)
-    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.clearMergerLocs()
     DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
     val parts = 2
     val shuffleMapRdd = new MyRDD(sc, parts, Nil)
@@ -3459,7 +3502,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
   test("SPARK-32920: merger locations not empty") {
     initPushBasedShuffleConfs(conf)
     conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 3)
-    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.clearMergerLocs()
     DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
     val parts = 2
 
@@ -3484,7 +3527,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
   test("SPARK-32920: merger locations reuse from shuffle dependency") {
     initPushBasedShuffleConfs(conf)
     conf.set(config.SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS, 3)
-    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.clearMergerLocs()
     DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
     val parts = 2
 
@@ -3524,7 +3567,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
   test("SPARK-32920: Disable shuffle merge due to not enough mergers available") {
     initPushBasedShuffleConfs(conf)
     conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 6)
-    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.clearMergerLocs()
     DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
     val parts = 7
 
@@ -3548,7 +3591,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
   test("SPARK-32920: Ensure child stage should not start before all the" +
       " parent stages are completed with shuffle merge finalized for all the parent stages") {
     initPushBasedShuffleConfs(conf)
-    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.clearMergerLocs()
     DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
     val parts = 1
     val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
@@ -3585,7 +3628,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
       " ShuffleDependency should not cause DAGScheduler to hang") {
     initPushBasedShuffleConfs(conf)
     conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 10)
-    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.clearMergerLocs()
     DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
     val parts = 20
 
@@ -3616,7 +3659,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
       " ShuffleDependency with shuffle data loss should recompute missing partitions") {
     initPushBasedShuffleConfs(conf)
     conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 10)
-    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.clearMergerLocs()
     DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
     val parts = 20
 
@@ -3632,7 +3675,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
 
     completeNextResultStageWithSuccess(1, 0)
 
-    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.clearMergerLocs()
     val hosts = (6 to parts).map {x => s"Host$x" }
     DAGSchedulerSuite.addMergerLocs(hosts)
 
@@ -3669,7 +3712,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
   test("SPARK-32920: Merge results should be unregistered if the running stage is cancelled" +
     " before shuffle merge is finalized") {
     initPushBasedShuffleConfs(conf)
-    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.clearMergerLocs()
     DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
     scheduler = new MyDAGScheduler(
       sc,
@@ -3697,14 +3740,15 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId) == parts)
     val shuffleMapStageToCancel = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
     runEvent(StageCancelled(0, Option("Explicit cancel check")))
-    scheduler.handleShuffleMergeFinalized(shuffleMapStageToCancel)
+    scheduler.handleShuffleMergeFinalized(shuffleMapStageToCancel,
+      shuffleMapStageToCancel.shuffleDep.shuffleMergeId)
     assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId) == 0)
   }
 
   test("SPARK-32920: SPARK-35549: Merge results should not get registered" +
     " after shuffle merge finalization") {
     initPushBasedShuffleConfs(conf)
-    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.clearMergerLocs()
     DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
 
     scheduler = new MyDAGScheduler(
@@ -3733,7 +3777,8 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     val shuffleMapStage = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
     scheduler.handleRegisterMergeStatuses(shuffleMapStage, Seq((0, makeMergeStatus("hostA",
       shuffleDep.shuffleMergeId))))
-    scheduler.handleShuffleMergeFinalized(shuffleMapStage)
+    scheduler.handleShuffleMergeFinalized(shuffleMapStage,
+      shuffleMapStage.shuffleDep.shuffleMergeId)
     scheduler.handleRegisterMergeStatuses(shuffleMapStage, Seq((1, makeMergeStatus("hostA",
       shuffleDep.shuffleMergeId))))
     assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId) == 1)
@@ -3741,7 +3786,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
 
   test("SPARK-32920: Disable push based shuffle in the case of a barrier stage") {
     initPushBasedShuffleConfs(conf)
-    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.clearMergerLocs()
     DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
 
     val parts = 2
@@ -3788,7 +3833,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
 
   test("SPARK-32923: handle stage failure for indeterminate map stage with push-based shuffle") {
     initPushBasedShuffleConfs(conf)
-    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.clearMergerLocs()
     DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
     val (shuffleId1, shuffleId2) = constructIndeterminateStageFetchFailed()
 
@@ -3847,11 +3892,262 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
 
     // Job successful ended.
     assert(results === Map(0 -> 11, 1 -> 12))
+  }
+
+  test("SPARK-33701: check adaptive shuffle merge finalization triggered after" +
+    " stage completion") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 3)
+    DAGSchedulerSuite.clearMergerLocs()
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
+    val parts = 2
+
+    val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(parts))
+    val shuffleMapRdd2 = new MyRDD(sc, parts, Nil)
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2),
+      tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends on the two shuffles, which will create the map stages
+    submit(reduceRdd, (0 until parts).toArray)
+
+    val taskResults = taskSets(0).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + ('A' + idx).toChar, parts))
+    }.toSeq
+    for ((result, i) <- taskResults.zipWithIndex) {
+      runEvent(makeCompletionEvent(taskSets(0).tasks(i), result._1, result._2))
+    }
+    // Verify finalize task is set with default delay of 10s and merge results are marked
+    // for registration
+    val shuffleStage1 = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage1.shuffleDep.shuffleMergeEnabled)
+    val finalizeTask1 = shuffleStage1.shuffleDep.getFinalizeTask.get
+      .asInstanceOf[DummyScheduledFuture]
+    assert(finalizeTask1.delay == 10 && finalizeTask1.registerMergeResults)
+    assert(shuffleStage1.shuffleDep.shuffleMergeFinalized)
+
+    complete(taskSets(1), taskSets(1).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + ('A' + idx).toChar, parts, 10))
+    }.toSeq)
+    val shuffleStage2 = scheduler.stageIdToStage(1).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage2.shuffleDep.getFinalizeTask.nonEmpty)
+    val finalizeTask2 = shuffleStage2.shuffleDep.getFinalizeTask.get
+      .asInstanceOf[DummyScheduledFuture]
+    assert(finalizeTask2.delay == 10 && finalizeTask2.registerMergeResults)
+
+    assert(mapOutputTracker.
+      getNumAvailableMergeResults(shuffleStage1.shuffleDep.shuffleId) == parts)
+    assert(mapOutputTracker.
+      getNumAvailableMergeResults(shuffleStage2.shuffleDep.shuffleId) == parts)
+    completeNextResultStageWithSuccess(2, 0)
+    assert(results === Map(0 -> 42, 1 -> 42))
+
     results.clear()
     assertDataStructuresEmpty()
   }
 
-  /**
+  test("SPARK-33701: check adaptive shuffle merge finalization triggered after minimum" +
+    " threshold push complete") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT, 10L)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 5)
+    conf.set(config.PUSH_BASED_SHUFFLE_MIN_PUSH_RATIO, 0.5)
+    DAGSchedulerSuite.clearMergerLocs()
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
+    val parts = 4
+
+    val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(parts))
+    val shuffleMapRdd2 = new MyRDD(sc, parts, Nil)
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2),
+      tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends on the two shuffles, which will create the map stages
+    submit(reduceRdd, (0 until parts).toArray)
+
+    val taskResults = taskSets(0).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + ('A' + idx).toChar, parts))
+    }.toSeq
+
+    runEvent(makeCompletionEvent(taskSets(0).tasks(0), taskResults(0)._1, taskResults(0)._2))
+    runEvent(makeCompletionEvent(taskSets(0).tasks(1), taskResults(0)._1, taskResults(0)._2))
+
+    val shuffleStage1 = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage1.shuffleDep.shuffleMergeEnabled)
+
+    pushComplete(shuffleStage1.shuffleDep.shuffleId, 0, 0)
+    pushComplete(shuffleStage1.shuffleDep.shuffleId, 0, 1)
+
+    // Minimum push complete for 2 tasks, should have scheduled merge finalization
+    val finalizeTask = shuffleStage1.shuffleDep.getFinalizeTask.get
+      .asInstanceOf[DummyScheduledFuture]
+    assert(finalizeTask.registerMergeResults && finalizeTask.delay == 0)
+
+    runEvent(makeCompletionEvent(taskSets(0).tasks(2), taskResults(0)._1, taskResults(0)._2))
+    runEvent(makeCompletionEvent(taskSets(0).tasks(3), taskResults(0)._1, taskResults(0)._2))
+
+    completeShuffleMapStageSuccessfully(1, 0, parts)
+
+    completeNextResultStageWithSuccess(2, 0)
+    assert(results === Map(0 -> 42, 1 -> 42, 2 -> 42, 3 -> 42))
+
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  // Test the behavior of stage cancellation during the spark.shuffle.push.finalize.timeout
+  // wait for shuffle merge finalization, there are 2 different cases:
+  // 1. Deterministic stage - With a deterministic stage, the shuffleMergeId stays 0 across
+  // multiple stage attempts, so if the stage is cancelled before the shuffle merge is
+  // finalized then the merge results are unregistered from MapOutputTracker
+  // 2. Indeterminate stage - A different attempt of the same stage can trigger shuffle merge
+  // finalization but it is validated by the shuffleMergeId (unique across stages and stage
+  // attempts for indeterminate stages) and only the shuffle merge is finalized
+  test("SPARK-33701: check adaptive shuffle merge finalization behavior with stage" +
+    " cancellation during spark.shuffle.push.finalize.timeout wait") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT, 10L)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 5)
+    conf.set(config.PUSH_BASED_SHUFFLE_MIN_PUSH_RATIO, 0.5)
+    DAGSchedulerSuite.clearMergerLocs()
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
+    val parts = 4
+
+    scheduler = new MyDAGScheduler(
+      sc,
+      taskScheduler,
+      sc.listenerBus,
+      mapOutputTracker,
+      blockManagerMaster,
+      sc.env,
+      shuffleMergeFinalize = false)
+    dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler)
+
+    // Determinate stage
+    val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(parts))
+    val shuffleMapRdd2 = new MyRDD(sc, parts, Nil)
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2),
+      tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends on the two shuffles, which will create the map stages
+    submit(reduceRdd, (0 until parts).toArray)
+
+    val taskResults = taskSets(0).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + ('A' + idx).toChar, parts))
+    }.toSeq
+
+    for ((result, i) <- taskResults.zipWithIndex) {
+      runEvent(makeCompletionEvent(taskSets(0).tasks(i), result._1, result._2))
+    }
+    val shuffleStage1 = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    runEvent(StageCancelled(0, Option("Explicit cancel check")))
+    scheduler.handleShuffleMergeFinalized(shuffleStage1, shuffleStage1.shuffleDep.shuffleMergeId)
+
+    assert(shuffleStage1.shuffleDep.shuffleMergeEnabled)
+    assert(!shuffleStage1.shuffleDep.shuffleMergeFinalized)
+    assert(mapOutputTracker.
+      getNumAvailableMergeResults(shuffleStage1.shuffleDep.shuffleId) == 0)
+
+    // Indeterminate stage
+    val shuffleMapIndeterminateRdd1 = new MyRDD(sc, parts, Nil, indeterminate = true)
+    val shuffleIndeterminateDep1 = new ShuffleDependency(
+      shuffleMapIndeterminateRdd1, new HashPartitioner(parts))
+    val shuffleMapIndeterminateRdd2 = new MyRDD(sc, parts, Nil, indeterminate = true)
+    val shuffleIndeterminateDep2 = new ShuffleDependency(
+      shuffleMapIndeterminateRdd2, new HashPartitioner(parts))
+    val reduceIndeterminateRdd = new MyRDD(sc, parts, List(
+      shuffleIndeterminateDep1, shuffleIndeterminateDep2), tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends on the two shuffles, which will create the map stages
+    submit(reduceIndeterminateRdd, (0 until parts).toArray)
+
+    val indeterminateResults = taskSets(0).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + ('A' + idx).toChar, parts))
+    }.toSeq
+
+    for ((result, i) <- indeterminateResults.zipWithIndex) {
+      runEvent(makeCompletionEvent(taskSets(0).tasks(i), result._1, result._2))
+    }
+
+    val shuffleIndeterminateStage = scheduler.stageIdToStage(3).asInstanceOf[ShuffleMapStage]
+    assert(shuffleIndeterminateStage.isIndeterminate)
+    scheduler.handleShuffleMergeFinalized(shuffleIndeterminateStage, 2)
+    assert(shuffleIndeterminateStage.shuffleDep.shuffleMergeEnabled)
+    assert(!shuffleIndeterminateStage.shuffleDep.shuffleMergeFinalized)
+  }
+
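The test above exercises that guard indirectly: handleShuffleMergeFinalized is invoked with a stale shuffleMergeId (2) for the indeterminate stage, and the merge is asserted to remain unfinalized. A minimal sketch of the guard shape, for illustration only (maybeFinalize and its parameters are hypothetical names, not the actual DAGScheduler API):

    // Hypothetical helper: a finalization request is honored only when its shuffleMergeId
    // matches the dependency's current shuffleMergeId, so a request carrying a stale id
    // (e.g. from an earlier attempt of an indeterminate stage) is simply dropped.
    def maybeFinalize(currentShuffleMergeId: Int, requestedShuffleMergeId: Int)
        (finalize: () => Unit): Unit = {
      if (currentShuffleMergeId == requestedShuffleMergeId) {
        finalize()
      }
    }

    // Example: a request with shuffleMergeId 2 against a dependency currently at 3 is ignored.
    maybeFinalize(currentShuffleMergeId = 3, requestedShuffleMergeId = 2)(() => println("finalized"))
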
+  // With adaptive shuffle merge finalization, once the minimum number of shuffle pushes
+  // completes after stage completion, the existing shuffle merge finalization task with
+  // delay = spark.shuffle.push.finalize.timeout should be replaced with a new shuffle merge
+  // finalization task with delay = 0 (see the scheduling sketch after this test)
+  test("SPARK-33701: check adaptive shuffle merge finalization with minimum pushes complete" +
+    " after the stage completion replacing the finalize task with delay = 0") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT, 10L)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 5)
+    conf.set(config.PUSH_BASED_SHUFFLE_MIN_PUSH_RATIO, 0.5)
+    DAGSchedulerSuite.clearMergerLocs()
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
+    val parts = 4
+
+    scheduler = new MyDAGScheduler(
+      sc,
+      taskScheduler,
+      sc.listenerBus,
+      mapOutputTracker,
+      blockManagerMaster,
+      sc.env,
+      shuffleMergeFinalize = false)
+    dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler)
+
+    // Determinate stage
+    val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(parts))
+    val shuffleMapRdd2 = new MyRDD(sc, parts, Nil)
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2),
+      tracker = mapOutputTracker)
+
+    // Submit a reduce job depending on both shuffle dependencies, which creates the map stages
+    submit(reduceRdd, (0 until parts).toArray)
+
+    val taskResults = taskSets(0).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + ('A' + idx).toChar, parts))
+    }.toSeq
+
+    for ((result, i) <- taskResults.zipWithIndex) {
+      runEvent(makeCompletionEvent(taskSets(0).tasks(i), result._1, result._2))
+    }
+    val shuffleStage1 = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage1.shuffleDep.shuffleMergeEnabled)
+    assert(!shuffleStage1.shuffleDep.shuffleMergeFinalized)
+    val finalizeTask1 = shuffleStage1.shuffleDep.getFinalizeTask.get.
+      asInstanceOf[DummyScheduledFuture]
+    assert(finalizeTask1.delay == 10 && finalizeTask1.registerMergeResults)
+
+    // Minimum number of shuffle pushes complete; this should replace the finalizeTask having
+    // delay = 10 with one having delay = 0
+    pushComplete(shuffleStage1.shuffleDep.shuffleId, 0, 0)
+    pushComplete(shuffleStage1.shuffleDep.shuffleId, 0, 1)
+
+    // Existing finalizeTask with delay = 10 should be replaced with finalizeTask
+    // with delay = 0
+    val finalizeTask2 = shuffleStage1.shuffleDep.getFinalizeTask.get.
+      asInstanceOf[DummyScheduledFuture]
+    assert(finalizeTask2.delay == 0 && finalizeTask2.registerMergeResults)
+  }
+
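A simplified sketch of the rescheduling pattern this test checks, built on a plain ScheduledExecutorService. FinalizeScheduler, scheduleWithTimeout, and onMinPushRatioReached are hypothetical names used only to illustrate the idea; they are not the DAGScheduler's actual implementation:

    import java.util.concurrent.{Executors, ScheduledFuture, TimeUnit}

    class FinalizeScheduler(finalizeTimeoutSec: Long) {
      private val scheduler = Executors.newSingleThreadScheduledExecutor()
      @volatile private var pending: Option[ScheduledFuture[_]] = None

      // Scheduled when the stage completes: finalize after the configured timeout.
      def scheduleWithTimeout(finalize: () => Unit): Unit = {
        pending = Some(scheduler.schedule(new Runnable {
          override def run(): Unit = finalize()
        }, finalizeTimeoutSec, TimeUnit.SECONDS))
      }

      // Called once the minimum push ratio is reached: cancel the delayed task (if it has not
      // started) and replace it with an immediate one, i.e. delay = 0.
      def onMinPushRatioReached(finalize: () => Unit): Unit = {
        pending.foreach(_.cancel(false))
        pending = Some(scheduler.schedule(new Runnable {
          override def run(): Unit = finalize()
        }, 0, TimeUnit.SECONDS))
      }
    }
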
   /**
    * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
    * Note that this checks only the host and not the executor ID.
    */
@@ -3922,7 +4218,7 @@ object DAGSchedulerSuite {
     locs.foreach { loc => mergerLocs.append(makeBlockManagerId(loc)) }
   }
 
-  def clearMergerLocs: Unit = mergerLocs.clear()
+  def clearMergerLocs(): Unit = mergerLocs.clear()
 
 }
 
diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala
index 298ba50..94c0417 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala
@@ -20,11 +20,11 @@ package org.apache.spark.shuffle
 import java.io.{File, FileNotFoundException, IOException}
 import java.net.ConnectException
 import java.nio.ByteBuffer
-import java.util.concurrent.LinkedBlockingQueue
+import java.util.concurrent.{CountDownLatch, LinkedBlockingQueue, Semaphore}
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.mockito.{Mock, MockitoAnnotations}
+import org.mockito.{ArgumentMatchers, Mock, MockitoAnnotations}
 import org.mockito.Answers.RETURNS_SMART_NULLS
 import org.mockito.ArgumentMatchers.any
 import org.mockito.Mockito._
@@ -32,6 +32,8 @@ import org.mockito.invocation.InvocationOnMock
 import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark._
+import org.apache.spark.executor.CoarseGrainedExecutorBackend
+import org.apache.spark.internal.config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.server.BlockPushNonFatalFailure
 import org.apache.spark.network.server.BlockPushNonFatalFailure.ReturnCode
@@ -40,12 +42,14 @@ import org.apache.spark.network.util.TransportConf
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.shuffle.ShuffleBlockPusher.PushRequest
 import org.apache.spark.storage._
+import org.apache.spark.util.ThreadUtils
 
 class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
 
   @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _
   @Mock(answer = RETURNS_SMART_NULLS) private var dependency: ShuffleDependency[Int, Int, Int] = _
   @Mock(answer = RETURNS_SMART_NULLS) private var shuffleClient: BlockStoreClient = _
+  @Mock(answer = RETURNS_SMART_NULLS) private var executorBackend: CoarseGrainedExecutorBackend = _
 
   private var conf: SparkConf = _
   private var pushedBlocks = new ArrayBuffer[String]
@@ -54,6 +58,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     super.beforeEach()
     conf = new SparkConf(loadDefaults = false)
     MockitoAnnotations.openMocks(this).close()
+    when(dependency.shuffleId).thenReturn(0)
     when(dependency.partitioner).thenReturn(new HashPartitioner(8))
     when(dependency.serializer).thenReturn(new JavaSerializer(conf))
     when(dependency.getMergerLocs).thenReturn(Seq(BlockManagerId("test-client", "test-client", 1)))
@@ -62,6 +67,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     when(mockEnv.conf).thenReturn(conf)
     when(mockEnv.blockManager).thenReturn(blockManager)
     SparkEnv.set(mockEnv)
+    when(SparkEnv.get.executorBackend).thenReturn(Some(executorBackend))
     when(blockManager.blockStoreClient).thenReturn(shuffleClient)
   }
 
@@ -91,37 +97,104 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     })
   }
 
+  private def verifyBlockPushCompleted(
+      blockPusher: ShuffleBlockPusher): Unit = {
+    verify(executorBackend, times(1))
+      .notifyDriverAboutPushCompletion(dependency.shuffleId, 0, 0)
+    assert(blockPusher.isPushCompletionNotified)
+  }
+
   test("A batch of blocks is limited by maxBlocksBatchSize") {
+    interceptPushedBlocksForSuccess()
     conf.set("spark.shuffle.push.maxBlockBatchSize", "1m")
     conf.set("spark.shuffle.push.maxBlockSizeToPush", "2048k")
     val blockPusher = new TestShuffleBlockPusher(conf)
     val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port))
     val largeBlockSize = 2 * 1024 * 1024
+    blockPusher.initiateBlockPush(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 5 }, dependency, 0)
     val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, 0,
       mock(classOf[File]), Array(2, 2, 2, largeBlockSize, largeBlockSize), mergerLocs,
       mock(classOf[TransportConf]))
+    blockPusher.runPendingTasks()
     assert(pushRequests.length == 3)
+    verifyBlockPushCompleted(blockPusher)
     verifyPushRequests(pushRequests, Seq(6, largeBlockSize, largeBlockSize))
   }
 
   test("Large blocks are excluded in the preparation") {
+    interceptPushedBlocksForSuccess()
     conf.set("spark.shuffle.push.maxBlockSizeToPush", "1k")
     val blockPusher = new TestShuffleBlockPusher(conf)
     val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port))
+    blockPusher.initiateBlockPush(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 5 }, dependency, 0)
     val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, 0,
       mock(classOf[File]), Array(2, 2, 2, 1028, 1024), mergerLocs, mock(classOf[TransportConf]))
+    blockPusher.runPendingTasks()
     assert(pushRequests.length == 2)
     verifyPushRequests(pushRequests, Seq(6, 1024))
+    verifyBlockPushCompleted(blockPusher)
   }
 
   test("Number of blocks in a push request are limited by maxBlocksInFlightPerAddress ") {
+    interceptPushedBlocksForSuccess()
     conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1")
     val blockPusher = new TestShuffleBlockPusher(conf)
     val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port))
+    blockPusher.initiateBlockPush(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 5 }, dependency, 0)
     val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, 0,
       mock(classOf[File]), Array(2, 2, 2, 2, 2), mergerLocs, mock(classOf[TransportConf]))
+    blockPusher.runPendingTasks()
     assert(pushRequests.length == 5)
     verifyPushRequests(pushRequests, Seq(2, 2, 2, 2, 2))
+    verifyBlockPushCompleted(blockPusher)
+  }
+
+  test("SPARK-33701: Ensure all the blocks are pushed before notifying driver" +
+    " about push completion") {
+    conf.set(REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS, 12)
+    conf.set("spark.shuffle.push.maxBlockBatchSize", "20b")
+    val latch = new CountDownLatch(1)
+    // Use two different remote servers so that two separate push requests are sent, ensuring
+    // that all the blocks are pushed before notifying the driver about push completion
+    when(dependency.getMergerLocs).thenReturn(Seq(BlockManagerId("test-client", "test-client", 1),
+      BlockManagerId("slow-client", "slow-client", 1)))
+    when(shuffleClient.pushBlocks(ArgumentMatchers.eq("slow-client"), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+        val blockPushListener = invocation.getArguments()(4).asInstanceOf[BlockPushingListener]
+        latch.await()
+        // Add a small wait here to delay onBlockPushSuccess, mimicking a slow remote server
+        Thread.sleep(500)
+        blocks.foreach { blockId =>
+          blockPushListener.onBlockPushSuccess(blockId, mock(classOf[ManagedBuffer]))
+        }
+      })
+    when(shuffleClient.pushBlocks(ArgumentMatchers.eq("test-client"), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+        val blockPushListener = invocation.getArguments()(4).asInstanceOf[BlockPushingListener]
+        latch.await()
+        blocks.foreach { blockId =>
+          blockPushListener.onBlockPushSuccess(blockId, mock(classOf[ManagedBuffer]))
+        }
+      })
+    val semaphore = new Semaphore(0)
+    val blockPusher = new ConcurrentTestBlockPusher(conf, semaphore)
+    val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port))
+    blockPusher.initiateBlockPush(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 5 }, dependency, 0)
+    val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, 0,
+      mock(classOf[File]), Array(2, 2, 2, 2, 2), mergerLocs, mock(classOf[TransportConf]))
+    latch.countDown()
+    latch.countDown()
+    semaphore.acquire()
+    assert(blockPusher.bytesInFlight <= 0)
+    assert(pushRequests.length == 2)
+    verifyPushRequests(pushRequests, Seq(6, 4))
+    verifyBlockPushCompleted(blockPusher)
   }
 
   test("Basic block push") {
@@ -133,6 +206,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     verify(shuffleClient, times(1))
       .pushBlocks(any(), any(), any(), any(), any())
     assert(pushedBlocks.length == dependency.partitioner.numPartitions)
+    verifyBlockPushCompleted(blockPusher)
     ShuffleBlockPusher.stop()
   }
 
@@ -146,6 +220,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     verify(shuffleClient, times(1))
       .pushBlocks(any(), any(), any(), any(), any())
     assert(pushedBlocks.length == dependency.partitioner.numPartitions - 1)
+    verifyBlockPushCompleted(pusher)
     ShuffleBlockPusher.stop()
   }
 
@@ -159,6 +234,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     verify(shuffleClient, times(8))
       .pushBlocks(any(), any(), any(), any(), any())
     assert(pushedBlocks.length == dependency.partitioner.numPartitions)
+    verifyBlockPushCompleted(pusher)
     ShuffleBlockPusher.stop()
   }
 
@@ -199,6 +275,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     verify(shuffleClient, times(4))
       .pushBlocks(any(), any(), any(), any(), any())
     assert(pushedBlocks.length == 8)
+    verifyBlockPushCompleted(pusher)
     ShuffleBlockPusher.stop()
   }
 
@@ -213,6 +290,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     verify(shuffleClient, times(4))
       .pushBlocks(any(), any(), any(), any(), any())
     assert(pushedBlocks.length == dependency.partitioner.numPartitions)
+    verifyBlockPushCompleted(pusher)
     ShuffleBlockPusher.stop()
   }
 
@@ -279,6 +357,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     verify(shuffleClient, times(8))
       .pushBlocks(any(), any(), any(), any(), any())
     assert(pushedBlocks.length == 7)
+    verifyBlockPushCompleted(pusher)
   }
 
   test("More blocks are not pushed when a block push fails with too late " +
@@ -333,6 +412,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     // 2 blocks for each merger locations
     assert(pushedBlocks.length == 4)
     assert(pusher.unreachableBlockMgrs.size == 2)
+    verifyBlockPushCompleted(pusher)
   }
 
   test("SPARK-36255: FileNotFoundException stops the push") {
@@ -359,7 +439,8 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     ShuffleBlockPusher.stop()
   }
 
-  private class TestShuffleBlockPusher(conf: SparkConf) extends ShuffleBlockPusher(conf) {
+  private class TestShuffleBlockPusher(
+      conf: SparkConf) extends ShuffleBlockPusher(conf) {
     val tasks = new LinkedBlockingQueue[Runnable]
 
     override protected def submitTask(task: Runnable): Unit = {
@@ -385,4 +466,18 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
       managedBuffer
     }
   }
+
+  private class ConcurrentTestBlockPusher(conf: SparkConf, semaphore: Semaphore)
+      extends TestShuffleBlockPusher(conf) {
+    val blockPusher = ThreadUtils.newDaemonFixedThreadPool(1, "test-block-pusher")
+
+    override protected def submitTask(task: Runnable): Unit = {
+      blockPusher.execute(task)
+    }
+
+    override def notifyDriverAboutPushCompletion(): Unit = {
+      super.notifyDriverAboutPushCompletion()
+      semaphore.release()
+    }
+  }
 }
diff --git a/docs/configuration.md b/docs/configuration.md
index 2d4164f..80f17a8 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -3268,4 +3268,20 @@ Push-based shuffle helps improve the reliability and performance of spark shuffl
   </td>
   <td>3.2.0</td>
 </tr>
+<tr>
+  <td><code>spark.shuffle.push.minShuffleSizeToWait</code></td>
+  <td><code>500m</code></td>
+  <td>
+    The driver will wait for merge finalization to complete only if the total shuffle data size is more than this threshold. If the total shuffle size is less, the driver will immediately finalize the shuffle output.
+  </td>
+  <td>3.3.0</td>
+</tr>
+<tr>
+  <td><code>spark.shuffle.push.minCompletedPushRatio</code></td>
+  <td><code>1.0</code></td>
+  <td>
+    The minimum fraction of map partitions that should be push complete before the driver starts shuffle merge finalization during push-based shuffle.
+  </td>
+  <td>3.3.0</td>
+</tr>
 </table>
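
The two properties above can be set like any other Spark configuration. A minimal sketch, assuming push-based shuffle is already enabled for the application; the values are illustrative only, not recommendations:

    import org.apache.spark.SparkConf

    // Illustrative values only; spark.shuffle.push.enabled must also be set for these to matter.
    val conf = new SparkConf()
      .set("spark.shuffle.push.enabled", "true")
      .set("spark.shuffle.push.minShuffleSizeToWait", "100m")   // finalize immediately below this total shuffle size
      .set("spark.shuffle.push.minCompletedPushRatio", "0.75")  // finalize once 75% of map partitions finish pushing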
