Posted to commits@spark.apache.org by li...@apache.org on 2020/03/17 07:24:55 UTC

[spark] branch branch-3.0 updated: [SPARK-31134][SQL] optimize skew join after shuffle partitions are coalesced

This is an automated email from the ASF dual-hosted git repository.

lixiao pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 0512b3f  [SPARK-31134][SQL] optimize skew join after shuffle partitions are coalesced
0512b3f is described below

commit 0512b3f427274c8bda249fba02cd16f5694a4ea5
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Tue Mar 17 00:23:16 2020 -0700

    [SPARK-31134][SQL] optimize skew join after shuffle partitions are coalesced
    
    ### What changes were proposed in this pull request?
    
    Run the `OptimizeSkewedJoin` rule after the `CoalesceShufflePartitions` rule.
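    
    For context, a minimal sketch (not part of this patch) of the session
    configs that exercise both rules together, assuming the Spark 3.0
    property names:
    
        // Enable AQE plus both features: coalescing runs first, then
        // OptimizeSkewedJoin splits any remaining skewed partitions.
        spark.conf.set("spark.sql.adaptive.enabled", "true")
        spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
        spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")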
    
    ### Why are the changes needed?
    
    Remove duplicated coalescing code in `OptimizeSkewedJoin`.
    
    ### Does this PR introduce any user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Existing tests.
    
    Closes #27893 from cloud-fan/aqe.
    
    Authored-by: Wenchen Fan <we...@databricks.com>
    Signed-off-by: gatorsmile <ga...@gmail.com>
    (cherry picked from commit 30d95356f1881c32eb39e51525d2bcb331fcf867)
    Signed-off-by: gatorsmile <ga...@gmail.com>
---
 .../execution/adaptive/AdaptiveSparkPlanExec.scala |   9 +-
 .../adaptive/CoalesceShufflePartitions.scala       |   2 -
 .../execution/adaptive/OptimizeSkewedJoin.scala    | 272 ++++++++++-----------
 .../execution/adaptive/ShufflePartitionsUtil.scala |  18 +-
 .../sql/execution/ShufflePartitionsUtilSuite.scala |   2 -
 5 files changed, 146 insertions(+), 157 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 68da06d..b54a32f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -96,13 +96,10 @@ case class AdaptiveSparkPlanExec(
   // optimizations should be stage-independent.
   @transient private val queryStageOptimizerRules: Seq[Rule[SparkPlan]] = Seq(
     ReuseAdaptiveSubquery(conf, context.subqueryCache),
-    // Here the 'OptimizeSkewedJoin' rule should be executed
-    // before 'CoalesceShufflePartitions', as the skewed partition handled
-    // in 'OptimizeSkewedJoin' rule, should be omitted in 'CoalesceShufflePartitions'.
-    OptimizeSkewedJoin(conf),
     CoalesceShufflePartitions(context.session),
-    // The rule of 'OptimizeLocalShuffleReader' need to make use of the 'partitionStartIndices'
-    // in 'CoalesceShufflePartitions' rule. So it must be after 'CoalesceShufflePartitions' rule.
+    // The following two rules need to make use of 'CustomShuffleReaderExec.partitionSpecs'
+    // added by `CoalesceShufflePartitions`. So they must be executed after it.
+    OptimizeSkewedJoin(conf),
     OptimizeLocalShuffleReader(conf),
     ApplyColumnarRulesAndInsertTransitions(conf, context.session.sessionState.columnarRules),
     CollapseCodegenStages(conf)
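
The ordering matters because these rules are applied to the plan in sequence,
so a rule that consumes what an earlier rule produced must appear after it. A
simplified sketch of that contract (not the actual AdaptiveSparkPlanExec code):

    // Each rule sees the output of the previous one, so OptimizeSkewedJoin
    // observes the CustomShuffleReaderExec nodes that
    // CoalesceShufflePartitions inserted.
    val optimizedPlan = queryStageOptimizerRules.foldLeft(plan) {
      (currentPlan, rule) => rule.apply(currentPlan)
    }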
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
index d2a7f6a..226d692 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
@@ -74,8 +74,6 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl
           .getOrElse(session.sparkContext.defaultParallelism)
         val partitionSpecs = ShufflePartitionsUtil.coalescePartitions(
           validMetrics.toArray,
-          firstPartitionIndex = 0,
-          lastPartitionIndex = distinctNumPreShufflePartitions.head,
           advisoryTargetSize = conf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES),
           minNumPartitions = minPartitionNum)
         // This transformation adds new nodes, so we must use `transformUp` here.
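
With the explicit partition range gone, coalescing always covers the full
[0, numPartitions) range. A hedged sketch of the simplified call, using
illustrative values rather than the patch's own:

    // Hypothetical usage: coalesce all partitions toward a 64MB target,
    // never producing fewer than 10 partitions.
    val specs = ShufflePartitionsUtil.coalescePartitions(
      validMetrics.toArray,
      advisoryTargetSize = 64 * 1024 * 1024,
      minNumPartitions = 10)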
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
index db65af6..e02b9af 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
 
 import org.apache.commons.io.FileUtils
 
-import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkContext, SparkEnv}
+import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
@@ -83,14 +83,14 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
   /**
    * The goal of skew join optimization is to make the data distribution more even. The target size
    * to split skewed partitions is the average size of the non-skewed partitions, or the
-   * target post-shuffle partition size if avg size is smaller than it.
+   * advisory partition size if the average size is smaller than it.
    */
-  private def targetSize(stats: MapOutputStatistics, medianSize: Long): Long = {
-    val targetPostShuffleSize = conf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES)
-    val nonSkewSizes = stats.bytesByPartitionId.filterNot(isSkewed(_, medianSize))
+  private def targetSize(sizes: Seq[Long], medianSize: Long): Long = {
+    val advisorySize = conf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES)
+    val nonSkewSizes = sizes.filterNot(isSkewed(_, medianSize))
     // It's impossible that all the partitions are skewed, as we use median size to define skew.
     assert(nonSkewSizes.nonEmpty)
-    math.max(targetPostShuffleSize, nonSkewSizes.sum / nonSkewSizes.length)
+    math.max(advisorySize, nonSkewSizes.sum / nonSkewSizes.length)
   }
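
A worked example of the formula above, with illustrative numbers (assuming
the common 64MB advisory partition size; not taken from the patch):

    sizes      = [10MB, 12MB, 14MB, 16MB, 500MB]   // median = 14MB
    skewed     = the 500MB partition (well above the median-based threshold)
    nonSkewAvg = (10 + 12 + 14 + 16) / 4 = 13MB
    target     = max(64MB, 13MB) = 64MB
    // so the 500MB partition is split into roughly eight ~64MB pieces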
 
   /**
@@ -102,21 +102,29 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
   }
 
   /**
-   * Split the skewed partition based on the map size and the max split number.
+   * Splits the skewed partition based on the map sizes and the target partition size
+   * after split, and creates a list of `PartialReducerPartitionSpec`. Returns None if it can't split.
    */
-  private def getMapStartIndices(
-      stage: ShuffleQueryStageExec,
-      partitionId: Int,
-      targetSize: Long): Seq[Int] = {
-    val shuffleId = stage.shuffle.shuffleDependency.shuffleHandle.shuffleId
-    val mapPartitionSizes = getMapSizesForReduceId(shuffleId, partitionId)
-    ShufflePartitionsUtil.splitSizeListByTargetSize(mapPartitionSizes, targetSize)
-  }
-
-  private def getStatistics(stage: ShuffleQueryStageExec): MapOutputStatistics = {
-    assert(stage.resultOption.isDefined, "ShuffleQueryStageExec should" +
-      " already be ready when executing OptimizeSkewedPartitions rule")
-    stage.resultOption.get.asInstanceOf[MapOutputStatistics]
+  private def createSkewPartitionSpecs(
+      shuffleId: Int,
+      reducerId: Int,
+      targetSize: Long): Option[Seq[PartialReducerPartitionSpec]] = {
+    val mapPartitionSizes = getMapSizesForReduceId(shuffleId, reducerId)
+    val mapStartIndices = ShufflePartitionsUtil.splitSizeListByTargetSize(
+      mapPartitionSizes, targetSize)
+    if (mapStartIndices.length > 1) {
+      Some(mapStartIndices.indices.map { i =>
+        val startMapIndex = mapStartIndices(i)
+        val endMapIndex = if (i == mapStartIndices.length - 1) {
+          mapPartitionSizes.length
+        } else {
+          mapStartIndices(i + 1)
+        }
+        PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex)
+      })
+    } else {
+      None
+    }
   }
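
The index arithmetic above reads well in isolation: each start index opens a
[start, end) range of mapper outputs for the same reducer. A self-contained
sketch with hypothetical names (not Spark source):

    case class PartialRange(reducerId: Int, startMapIndex: Int, endMapIndex: Int)

    def toRanges(reducerId: Int, starts: Seq[Int], numMappers: Int): Seq[PartialRange] =
      starts.indices.map { i =>
        // The last range is closed by the mapper count, the others by the
        // next start index.
        val end = if (i == starts.length - 1) numMappers else starts(i + 1)
        PartialRange(reducerId, starts(i), end)
      }

    // toRanges(3, Seq(0, 2, 4), 5)
    //   == Seq(PartialRange(3, 0, 2), PartialRange(3, 2, 4), PartialRange(3, 4, 5))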
 
   private def canSplitLeftSide(joinType: JoinType) = {
@@ -128,12 +136,9 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
     joinType == Inner || joinType == Cross || joinType == RightOuter
   }
 
-  private def getNumMappers(stage: ShuffleQueryStageExec): Int = {
-    stage.shuffle.shuffleDependency.rdd.partitions.length
-  }
-
-  private def getSizeInfo(medianSize: Long, maxSize: Long): String = {
-    s"median size: $medianSize, max size: ${maxSize}"
+  private def getSizeInfo(medianSize: Long, sizes: Seq[Long]): String = {
+    s"median size: $medianSize, max size: ${sizes.max}, min size: ${sizes.min}, avg size: " +
+      sizes.sum / sizes.length
   }
 
   /*
@@ -150,101 +155,90 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
    */
   def optimizeSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp {
     case smj @ SortMergeJoinExec(_, _, joinType, _,
-        s1 @ SortExec(_, _, left: ShuffleQueryStageExec, _),
-        s2 @ SortExec(_, _, right: ShuffleQueryStageExec, _), _)
+        s1 @ SortExec(_, _, ShuffleStage(left: ShuffleStageInfo), _),
+        s2 @ SortExec(_, _, ShuffleStage(right: ShuffleStageInfo), _), _)
         if supportedJoinTypes.contains(joinType) =>
-      val leftStats = getStatistics(left)
-      val rightStats = getStatistics(right)
-      val numPartitions = leftStats.bytesByPartitionId.length
-
-      val leftMedSize = medianSize(leftStats)
-      val rightMedSize = medianSize(rightStats)
+      assert(left.partitionsWithSizes.length == right.partitionsWithSizes.length)
+      val numPartitions = left.partitionsWithSizes.length
+      // We use the median size of the original shuffle partitions to detect skewed partitions.
+      val leftMedSize = medianSize(left.mapStats)
+      val rightMedSize = medianSize(right.mapStats)
       logDebug(
         s"""
-          |Try to optimize skewed join.
-          |Left side partition size:
-          |${getSizeInfo(leftMedSize, leftStats.bytesByPartitionId.max)}
-          |Right side partition size:
-          |${getSizeInfo(rightMedSize, rightStats.bytesByPartitionId.max)}
+          |Optimizing skewed join.
+          |Left side partitions size info:
+          |${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId)}
+          |Right side partitions size info:
+          |${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId)}
         """.stripMargin)
       val canSplitLeft = canSplitLeftSide(joinType)
       val canSplitRight = canSplitRightSide(joinType)
-      val leftTargetSize = targetSize(leftStats, leftMedSize)
-      val rightTargetSize = targetSize(rightStats, rightMedSize)
+      // We use the actual partition sizes (which may be coalesced) to calculate the target size,
+      // so that the final data distribution is even (coalesced partitions + split partitions).
+      val leftActualSizes = left.partitionsWithSizes.map(_._2)
+      val rightActualSizes = right.partitionsWithSizes.map(_._2)
+      val leftTargetSize = targetSize(leftActualSizes, leftMedSize)
+      val rightTargetSize = targetSize(rightActualSizes, rightMedSize)
 
       val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
       val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
-      // This is used to delay the creation of non-skew partitions so that we can potentially
-      // coalesce them like `CoalesceShufflePartitions` does.
-      val nonSkewPartitionIndices = mutable.ArrayBuffer.empty[Int]
       val leftSkewDesc = new SkewDesc
       val rightSkewDesc = new SkewDesc
       for (partitionIndex <- 0 until numPartitions) {
-        val leftSize = leftStats.bytesByPartitionId(partitionIndex)
-        val isLeftSkew = isSkewed(leftSize, leftMedSize) && canSplitLeft
-        val rightSize = rightStats.bytesByPartitionId(partitionIndex)
-        val isRightSkew = isSkewed(rightSize, rightMedSize) && canSplitRight
-        if (isLeftSkew || isRightSkew) {
-          if (nonSkewPartitionIndices.nonEmpty) {
-            // As soon as we see a skew, we'll "flush" out unhandled non-skew partitions.
-            createNonSkewPartitions(leftStats, rightStats, nonSkewPartitionIndices).foreach { p =>
-              leftSidePartitions += p
-              rightSidePartitions += p
-            }
-            nonSkewPartitionIndices.clear()
-          }
-
-          val leftParts = if (isLeftSkew) {
-            val mapStartIndices = getMapStartIndices(left, partitionIndex, leftTargetSize)
-            if (mapStartIndices.length > 1) {
-              leftSkewDesc.addPartitionSize(leftSize)
-              createSkewPartitions(partitionIndex, mapStartIndices, getNumMappers(left))
-            } else {
-              Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1))
-            }
-          } else {
-            Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1))
-          }
-
-          val rightParts = if (isRightSkew) {
-            val mapStartIndices = getMapStartIndices(right, partitionIndex, rightTargetSize)
-            if (mapStartIndices.length > 1) {
-              rightSkewDesc.addPartitionSize(rightSize)
-              createSkewPartitions(partitionIndex, mapStartIndices, getNumMappers(right))
-            } else {
-              Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1))
-            }
-          } else {
-            Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1))
+        val isLeftSkew = isSkewed(leftActualSizes(partitionIndex), leftMedSize) && canSplitLeft
+        val leftPartSpec = left.partitionsWithSizes(partitionIndex)._1
+        val isLeftCoalesced = leftPartSpec.startReducerIndex + 1 < leftPartSpec.endReducerIndex
+
+        val isRightSkew = isSkewed(rightActualSizes(partitionIndex), rightMedSize) && canSplitRight
+        val rightPartSpec = right.partitionsWithSizes(partitionIndex)._1
+        val isRightCoalesced = rightPartSpec.startReducerIndex + 1 < rightPartSpec.endReducerIndex
+
+        // A skewed partition should never be coalesced, but we skip it here just to be safe.
+        val leftParts = if (isLeftSkew && !isLeftCoalesced) {
+          val reducerId = leftPartSpec.startReducerIndex
+          val skewSpecs = createSkewPartitionSpecs(
+            left.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, leftTargetSize)
+          if (skewSpecs.isDefined) {
+            logDebug(s"Left side partition $partitionIndex is skewed, split it into " +
+              s"${skewSpecs.get.length} parts.")
+            leftSkewDesc.addPartitionSize(leftActualSizes(partitionIndex))
           }
+          skewSpecs.getOrElse(Seq(leftPartSpec))
+        } else {
+          Seq(leftPartSpec)
+        }
 
-          for {
-            leftSidePartition <- leftParts
-            rightSidePartition <- rightParts
-          } {
-            leftSidePartitions += leftSidePartition
-            rightSidePartitions += rightSidePartition
+        // A skewed partition should never be coalesced, but we skip it here just to be safe.
+        val rightParts = if (isRightSkew && !isRightCoalesced) {
+          val reducerId = rightPartSpec.startReducerIndex
+          val skewSpecs = createSkewPartitionSpecs(
+            right.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, rightTargetSize)
+          if (skewSpecs.isDefined) {
+            logDebug(s"Right side partition $partitionIndex is skewed, split it into " +
+              s"${skewSpecs.get.length} parts.")
+            rightSkewDesc.addPartitionSize(rightActualSizes(partitionIndex))
           }
+          skewSpecs.getOrElse(Seq(rightPartSpec))
         } else {
-          // Add to `nonSkewPartitionIndices` first, and add real partitions later, in case we can
-          // coalesce the non-skew partitions.
-          nonSkewPartitionIndices += partitionIndex
-          // If this is the last partition, add real partition immediately.
-          if (partitionIndex == numPartitions - 1) {
-            createNonSkewPartitions(leftStats, rightStats, nonSkewPartitionIndices).foreach { p =>
-              leftSidePartitions += p
-              rightSidePartitions += p
-            }
-            nonSkewPartitionIndices.clear()
-          }
+          Seq(rightPartSpec)
+        }
+
+        for {
+          leftSidePartition <- leftParts
+          rightSidePartition <- rightParts
+        } {
+          leftSidePartitions += leftSidePartition
+          rightSidePartitions += rightSidePartition
         }
       }
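
The nested comprehension above pairs every split from one side with every
split (or the single spec) from the other, so no join keys are lost. A
minimal illustration with hypothetical values:

    // One skewed left partition split in two, right side untouched:
    val leftParts  = Seq("left-part-0", "left-part-1")
    val rightParts = Seq("right-part-0")
    val pairs = for (l <- leftParts; r <- rightParts) yield (l, r)
    // pairs == Seq(("left-part-0","right-part-0"), ("left-part-1","right-part-0"))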
 
       logDebug("number of skewed partitions: " +
         s"left ${leftSkewDesc.numPartitions}, right ${rightSkewDesc.numPartitions}")
       if (leftSkewDesc.numPartitions > 0 || rightSkewDesc.numPartitions > 0) {
-        val newLeft = CustomShuffleReaderExec(left, leftSidePartitions, leftSkewDesc.toString)
-        val newRight = CustomShuffleReaderExec(right, rightSidePartitions, rightSkewDesc.toString)
+        val newLeft = CustomShuffleReaderExec(
+          left.shuffleStage, leftSidePartitions, leftSkewDesc.toString)
+        val newRight = CustomShuffleReaderExec(
+          right.shuffleStage, rightSidePartitions, rightSkewDesc.toString)
         smj.copy(
           left = s1.copy(child = newLeft), right = s2.copy(child = newRight), isSkewJoin = true)
       } else {
@@ -252,44 +246,6 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
       }
   }
 
-  private def createNonSkewPartitions(
-      leftStats: MapOutputStatistics,
-      rightStats: MapOutputStatistics,
-      nonSkewPartitionIndices: Seq[Int]): Seq[ShufflePartitionSpec] = {
-    assert(nonSkewPartitionIndices.nonEmpty)
-    val shouldCoalesce = conf.getConf(SQLConf.COALESCE_PARTITIONS_ENABLED)
-    if (!shouldCoalesce || nonSkewPartitionIndices.length == 1) {
-      nonSkewPartitionIndices.map(i => CoalescedPartitionSpec(i, i + 1))
-    } else {
-      // We fall back to Spark default parallelism if the minimum number of coalesced partitions
-      // is not set, so to avoid perf regressions compared to no coalescing.
-      val minPartitionNum = conf.getConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM)
-        .getOrElse(SparkContext.getActive.get.defaultParallelism)
-      ShufflePartitionsUtil.coalescePartitions(
-        Array(leftStats, rightStats),
-        firstPartitionIndex = nonSkewPartitionIndices.head,
-        // `lastPartitionIndex` is exclusive.
-        lastPartitionIndex = nonSkewPartitionIndices.last + 1,
-        advisoryTargetSize = conf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES),
-        minNumPartitions = minPartitionNum)
-    }
-  }
-
-  private def createSkewPartitions(
-      reducerIndex: Int,
-      mapStartIndices: Seq[Int],
-      numMappers: Int): Seq[PartialReducerPartitionSpec] = {
-    mapStartIndices.indices.map { i =>
-      val startMapIndex = mapStartIndices(i)
-      val endMapIndex = if (i == mapStartIndices.length - 1) {
-        numMappers
-      } else {
-        mapStartIndices(i + 1)
-      }
-      PartialReducerPartitionSpec(reducerIndex, startMapIndex, endMapIndex)
-    }
-  }
-
   override def apply(plan: SparkPlan): SparkPlan = {
     if (!conf.getConf(SQLConf.SKEW_JOIN_ENABLED)) {
       return plan
@@ -328,6 +284,48 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
   }
 }
 
+private object ShuffleStage {
+  def unapply(plan: SparkPlan): Option[ShuffleStageInfo] = plan match {
+    case s: ShuffleQueryStageExec =>
+      val mapStats = getMapStats(s)
+      val sizes = mapStats.bytesByPartitionId
+      val partitions = sizes.zipWithIndex.map {
+        case (size, i) => CoalescedPartitionSpec(i, i + 1) -> size
+      }
+      Some(ShuffleStageInfo(s, mapStats, partitions))
+
+    case CustomShuffleReaderExec(s: ShuffleQueryStageExec, partitionSpecs, _) =>
+      val mapStats = getMapStats(s)
+      val sizes = mapStats.bytesByPartitionId
+      val partitions = partitionSpecs.map {
+        case spec @ CoalescedPartitionSpec(start, end) =>
+          var sum = 0L
+          var i = start
+          while (i < end) {
+            sum += sizes(i)
+            i += 1
+          }
+          spec -> sum
+        case other => throw new IllegalArgumentException(
+          s"Expect CoalescedPartitionSpec but got $other")
+      }
+      Some(ShuffleStageInfo(s, mapStats, partitions))
+
+    case _ => None
+  }
+
+  private def getMapStats(stage: ShuffleQueryStageExec): MapOutputStatistics = {
+    assert(stage.resultOption.isDefined, "ShuffleQueryStageExec should" +
+      " already be ready when executing OptimizeSkewedJoin rule")
+    stage.resultOption.get.asInstanceOf[MapOutputStatistics]
+  }
+}
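
The extractor assigns each coalesced spec the sum of the original reducer
sizes it covers; a functional equivalent of that while loop, using a
hypothetical helper name (not Spark source):

    // Size of a CoalescedPartitionSpec covering reducers [start, end).
    def specSize(sizes: Array[Long], start: Int, end: Int): Long =
      (start until end).map(sizes(_)).sum

    // specSize(Array(10L, 20L, 30L), 0, 2) == 30L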
+
+private case class ShuffleStageInfo(
+    shuffleStage: ShuffleQueryStageExec,
+    mapStats: MapOutputStatistics,
+    partitionsWithSizes: Seq[(CoalescedPartitionSpec, Long)])
+
 private class SkewDesc {
   private[this] var numSkewedPartitions: Int = 0
   private[this] var totalSize: Long = 0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
index 292df11..208cc05 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
@@ -28,10 +28,9 @@ object ShufflePartitionsUtil extends Logging {
   final val MERGED_PARTITION_FACTOR = 1.2
 
   /**
-   * Coalesce the same range of partitions (`firstPartitionIndex` to `lastPartitionIndex`, the
-   * start is inclusive and the end is exclusive) from multiple shuffles. This method assumes that
-   * all the shuffles have the same number of partitions, and the partitions of same index will be
-   * read together by one task.
+   * Coalesce the partitions from multiple shuffles. This method assumes that all the shuffles
+   * have the same number of partitions, and the partitions with the same index will be read together
+   * by one task.
    *
    * The strategy used to determine the number of coalesced partitions is described as follows.
    * To determine the number of coalesced partitions, we have a target size for a coalesced
@@ -56,8 +55,6 @@ object ShufflePartitionsUtil extends Logging {
    */
   def coalescePartitions(
       mapOutputStatistics: Array[MapOutputStatistics],
-      firstPartitionIndex: Int,
-      lastPartitionIndex: Int,
       advisoryTargetSize: Long,
       minNumPartitions: Int): Seq[ShufflePartitionSpec] = {
     // If `minNumPartitions` is very large, it is possible that we need to use a value less than
@@ -87,11 +84,12 @@ object ShufflePartitionsUtil extends Logging {
       "There should be only one distinct value of the number of shuffle partitions " +
         "among registered Exchange operators.")
 
+    val numPartitions = distinctNumShufflePartitions.head
     val partitionSpecs = ArrayBuffer[CoalescedPartitionSpec]()
-    var latestSplitPoint = firstPartitionIndex
+    var latestSplitPoint = 0
     var coalescedSize = 0L
-    var i = firstPartitionIndex
-    while (i < lastPartitionIndex) {
+    var i = 0
+    while (i < numPartitions) {
       // We calculate the total size of i-th shuffle partitions from all shuffles.
       var totalSizeOfCurrentPartition = 0L
       var j = 0
@@ -112,7 +110,7 @@ object ShufflePartitionsUtil extends Logging {
       }
       i += 1
     }
-    partitionSpecs += CoalescedPartitionSpec(latestSplitPoint, lastPartitionIndex)
+    partitionSpecs += CoalescedPartitionSpec(latestSplitPoint, numPartitions)
 
     partitionSpecs
   }
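
A standalone sketch (not Spark source) of the greedy strategy the docs above
describe: walk partitions in order and cut a new coalesced partition whenever
adding the next one would overflow the target size. The real implementation
additionally sums the i-th partition size across all registered shuffles.

    def coalesce(sizes: Array[Long], target: Long): Seq[(Int, Int)] = {
      val splits = scala.collection.mutable.ArrayBuffer[(Int, Int)]()
      var start = 0
      var acc = 0L
      for (i <- sizes.indices) {
        if (i > start && acc + sizes(i) > target) {
          splits += ((start, i))  // [start, i) becomes one coalesced partition
          start = i
          acc = 0L
        }
        acc += sizes(i)
      }
      splits += ((start, sizes.length))
      splits.toSeq
    }

    // coalesce(Array(10L, 20L, 30L, 50L, 5L), 64L) == Seq((0, 3), (3, 5))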
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala
index 2cd4c98..7acc33c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala
@@ -33,8 +33,6 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite {
     }
     val estimatedPartitionStartIndices = ShufflePartitionsUtil.coalescePartitions(
       mapOutputStatistics,
-      0,
-      bytesByPartitionIdArray.head.length,
       targetSize,
       minNumPartitions)
     assert(estimatedPartitionStartIndices === expectedPartitionStartIndices)

