You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2021/04/26 06:21:33 UTC

[GitHub] [spark] maropu commented on a change in pull request #32328: [SPARK-35214][SQL] OptimizeSkewedJoin support ShuffledHashJoinExec

maropu commented on a change in pull request #32328:
URL: https://github.com/apache/spark/pull/32328#discussion_r620000087



##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
##########
@@ -19,15 +19,23 @@ package org.apache.spark.sql.execution.joins
 
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, LeftExistence, LeftOuter, RightOuter}
-import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution, Partitioning, PartitioningCollection, UnknownPartitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution, Partitioning, PartitioningCollection, UnknownPartitioning, UnspecifiedDistribution}
 
 /**
  * Holds common logic for join operators by shuffling two child relations
  * using the join keys.
  */
 trait ShuffledJoin extends BaseJoinExec {
+  def isSkewJoin: Boolean

Review comment:
       Could you update the description of this config in SQLConf accordingly? https://github.com/apache/spark/blob/38ef4771d447f6135382ee2767b3f32b96cb1b0e/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala#L541-L548

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
##########
@@ -157,98 +157,121 @@ object OptimizeSkewedJoin extends CustomShuffleReaderRule {
    * 4. Wrap the join right child with a special shuffle reader that reads partition0 3 times by
    *    3 tasks separately.
    */
-  def optimizeSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp {
-    case smj @ SortMergeJoinExec(_, _, joinType, _,
-        s1 @ SortExec(_, _, ShuffleStage(left: ShuffleStageInfo), _),
-        s2 @ SortExec(_, _, ShuffleStage(right: ShuffleStageInfo), _), _)
-        if supportedJoinTypes.contains(joinType) =>
-      assert(left.partitionsWithSizes.length == right.partitionsWithSizes.length)
-      val numPartitions = left.partitionsWithSizes.length
-      // We use the median size of the original shuffle partitions to detect skewed partitions.
-      val leftMedSize = medianSize(left.mapStats)
-      val rightMedSize = medianSize(right.mapStats)
-      logDebug(
-        s"""
-          |Optimizing skewed join.
-          |Left side partitions size info:
-          |${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId)}
-          |Right side partitions size info:
-          |${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId)}
-        """.stripMargin)
-      val canSplitLeft = canSplitLeftSide(joinType)
-      val canSplitRight = canSplitRightSide(joinType)
-      // We use the actual partition sizes (may be coalesced) to calculate target size, so that
-      // the final data distribution is even (coalesced partitions + split partitions).
-      val leftActualSizes = left.partitionsWithSizes.map(_._2)
-      val rightActualSizes = right.partitionsWithSizes.map(_._2)
-      val leftTargetSize = targetSize(leftActualSizes, leftMedSize)
-      val rightTargetSize = targetSize(rightActualSizes, rightMedSize)
-
-      val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
-      val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
-      var numSkewedLeft = 0
-      var numSkewedRight = 0
-      for (partitionIndex <- 0 until numPartitions) {
-        val leftActualSize = leftActualSizes(partitionIndex)
-        val isLeftSkew = isSkewed(leftActualSize, leftMedSize) && canSplitLeft
-        val leftPartSpec = left.partitionsWithSizes(partitionIndex)._1
-        val isLeftCoalesced = leftPartSpec.startReducerIndex + 1 < leftPartSpec.endReducerIndex
-
-        val rightActualSize = rightActualSizes(partitionIndex)
-        val isRightSkew = isSkewed(rightActualSize, rightMedSize) && canSplitRight
-        val rightPartSpec = right.partitionsWithSizes(partitionIndex)._1
-        val isRightCoalesced = rightPartSpec.startReducerIndex + 1 < rightPartSpec.endReducerIndex
-
-        // A skewed partition should never be coalesced, but skip it here just to be safe.
-        val leftParts = if (isLeftSkew && !isLeftCoalesced) {
-          val reducerId = leftPartSpec.startReducerIndex
-          val skewSpecs = createSkewPartitionSpecs(
-            left.mapStats.shuffleId, reducerId, leftTargetSize)
-          if (skewSpecs.isDefined) {
-            logDebug(s"Left side partition $partitionIndex " +
-              s"(${FileUtils.byteCountToDisplaySize(leftActualSize)}) is skewed, " +
-              s"split it into ${skewSpecs.get.length} parts.")
-            numSkewedLeft += 1
-          }
-          skewSpecs.getOrElse(Seq(leftPartSpec))
-        } else {
-          Seq(leftPartSpec)
+  private def getOptimizedChildren(

Review comment:
       nit: How about renaming `getOptimizedChildren` to `tryOptimizeChildren`? The method returns `None` when no skew is found, so "try" better describes it.

##########
File path: sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
##########
@@ -685,64 +691,82 @@ class AdaptiveQueryExecSuite
   }
 
   test("SPARK-29544: adaptive skew join with different join types") {
-    withSQLConf(
-      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
-      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
-      SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
-      SQLConf.SHUFFLE_PARTITIONS.key -> "100",
-      SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800",
-      SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") {
-      withTempView("skewData1", "skewData2") {
-        spark
-          .range(0, 1000, 1, 10)
-          .select(
-            when('id < 250, 249)
-              .when('id >= 750, 1000)
-              .otherwise('id).as("key1"),
-            'id as "value1")
-          .createOrReplaceTempView("skewData1")
-        spark
-          .range(0, 1000, 1, 10)
-          .select(
-            when('id < 250, 249)
-              .otherwise('id).as("key2"),
-            'id as "value2")
-          .createOrReplaceTempView("skewData2")
+    Seq("SHUFFLE_MERGE", "SHUFFLE_HASH").foreach { joinHint =>
+      val isSMJ = joinHint == "SHUFFLE_MERGE"
+      withSQLConf(
+        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+        SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
+        SQLConf.SHUFFLE_PARTITIONS.key -> "100",
+        SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800",
+        SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") {
+        withTempView("skewData1", "skewData2") {
+          spark
+            .range(0, 1000, 1, 10)
+            .select(
+              when('id < 250, 249)
+                .when('id >= 750, 1000)
+                .otherwise('id).as("key1"),
+              'id as "value1")
+            .createOrReplaceTempView("skewData1")
+          spark
+            .range(0, 1000, 1, 10)
+            .select(
+              when('id < 250, 249)
+                .otherwise('id).as("key2"),
+              'id as "value2")
+            .createOrReplaceTempView("skewData2")
 
-        def checkSkewJoin(
-            joins: Seq[SortMergeJoinExec],
-            leftSkewNum: Int,
-            rightSkewNum: Int): Unit = {
-          assert(joins.size == 1 && joins.head.isSkewJoin)
-          assert(joins.head.left.collect {
-            case r: CustomShuffleReaderExec => r
-          }.head.partitionSpecs.collect {
-            case p: PartialReducerPartitionSpec => p.reducerIndex
-          }.distinct.length == leftSkewNum)
-          assert(joins.head.right.collect {
-            case r: CustomShuffleReaderExec => r
-          }.head.partitionSpecs.collect {
-            case p: PartialReducerPartitionSpec => p.reducerIndex
-          }.distinct.length == rightSkewNum)
-        }
+          def checkSkewJoin(
+              joins: Seq[ShuffledJoin],
+              leftSkewNum: Int,
+              rightSkewNum: Int): Unit = {
+            assert(joins.size == 1 && joins.head.isSkewJoin)
+            assert(joins.head.left.collect {
+              case r: CustomShuffleReaderExec => r
+            }.head.partitionSpecs.collect {
+              case p: PartialReducerPartitionSpec => p.reducerIndex
+            }.distinct.length == leftSkewNum)
+            assert(joins.head.right.collect {
+              case r: CustomShuffleReaderExec => r
+            }.head.partitionSpecs.collect {
+              case p: PartialReducerPartitionSpec => p.reducerIndex
+            }.distinct.length == rightSkewNum)
+          }
 
-        // skewed inner join optimization
-        val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
-          "SELECT * FROM skewData1 join skewData2 ON key1 = key2")
-        val innerSmj = findTopLevelSortMergeJoin(innerAdaptivePlan)
-        checkSkewJoin(innerSmj, 2, 1)
-
-        // skewed left outer join optimization
-        val (_, leftAdaptivePlan) = runAdaptiveAndVerifyResult(
-          "SELECT * FROM skewData1 left outer join skewData2 ON key1 = key2")
-        val leftSmj = findTopLevelSortMergeJoin(leftAdaptivePlan)
-        checkSkewJoin(leftSmj, 2, 0)
-
-        // skewed right outer join optimization
-        val (_, rightAdaptivePlan) = runAdaptiveAndVerifyResult(
-          "SELECT * FROM skewData1 right outer join skewData2 ON key1 = key2")
-        val rightSmj = findTopLevelSortMergeJoin(rightAdaptivePlan)
-        checkSkewJoin(rightSmj, 0, 1)
+          // skewed inner join optimization
+          val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
+            s"SELECT /*+ $joinHint(skewData1) */ * FROM skewData1 " +
+              s"JOIN skewData2 ON key1 = key2")
+          val inner = if (isSMJ) {
+            findTopLevelSortMergeJoin(innerAdaptivePlan)
+          } else {
+            findTopLevelShuffledHashJoin(innerAdaptivePlan)
+          }
+          checkSkewJoin(inner, 2, 1)
+
+          // skewed left outer join optimization
+          val (_, leftAdaptivePlan) = runAdaptiveAndVerifyResult(
+            s"SELECT /*+ $joinHint(skewData2) */ * FROM skewData1 " +
+              s"LEFT OUTER JOIN skewData2 ON key1 = key2")
+          val leftSmj = if (isSMJ) {
+            findTopLevelSortMergeJoin(leftAdaptivePlan)
+          } else {
+            findTopLevelShuffledHashJoin(leftAdaptivePlan)
+          }
+          checkSkewJoin(leftSmj, 2, 0)
+
+          // skewed right outer join optimization
+          val (_, rightAdaptivePlan) = runAdaptiveAndVerifyResult(
+            s"SELECT /*+ $joinHint(skewData1) */ * FROM skewData1 " +
+              s"RIGHT OUTER JOIN skewData2 ON key1 = key2")

Review comment:
       ditto: the second string literal has no interpolated values, so its `s` prefix is unnecessary.

##########
File path: sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
##########
@@ -685,64 +691,82 @@ class AdaptiveQueryExecSuite
   }
 
   test("SPARK-29544: adaptive skew join with different join types") {
-    withSQLConf(
-      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
-      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
-      SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
-      SQLConf.SHUFFLE_PARTITIONS.key -> "100",
-      SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800",
-      SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") {
-      withTempView("skewData1", "skewData2") {
-        spark
-          .range(0, 1000, 1, 10)
-          .select(
-            when('id < 250, 249)
-              .when('id >= 750, 1000)
-              .otherwise('id).as("key1"),
-            'id as "value1")
-          .createOrReplaceTempView("skewData1")
-        spark
-          .range(0, 1000, 1, 10)
-          .select(
-            when('id < 250, 249)
-              .otherwise('id).as("key2"),
-            'id as "value2")
-          .createOrReplaceTempView("skewData2")
+    Seq("SHUFFLE_MERGE", "SHUFFLE_HASH").foreach { joinHint =>
+      val isSMJ = joinHint == "SHUFFLE_MERGE"
+      withSQLConf(
+        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+        SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
+        SQLConf.SHUFFLE_PARTITIONS.key -> "100",
+        SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800",
+        SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") {
+        withTempView("skewData1", "skewData2") {
+          spark
+            .range(0, 1000, 1, 10)
+            .select(
+              when('id < 250, 249)
+                .when('id >= 750, 1000)
+                .otherwise('id).as("key1"),
+              'id as "value1")
+            .createOrReplaceTempView("skewData1")
+          spark
+            .range(0, 1000, 1, 10)
+            .select(
+              when('id < 250, 249)
+                .otherwise('id).as("key2"),
+              'id as "value2")
+            .createOrReplaceTempView("skewData2")
 
-        def checkSkewJoin(
-            joins: Seq[SortMergeJoinExec],
-            leftSkewNum: Int,
-            rightSkewNum: Int): Unit = {
-          assert(joins.size == 1 && joins.head.isSkewJoin)
-          assert(joins.head.left.collect {
-            case r: CustomShuffleReaderExec => r
-          }.head.partitionSpecs.collect {
-            case p: PartialReducerPartitionSpec => p.reducerIndex
-          }.distinct.length == leftSkewNum)
-          assert(joins.head.right.collect {
-            case r: CustomShuffleReaderExec => r
-          }.head.partitionSpecs.collect {
-            case p: PartialReducerPartitionSpec => p.reducerIndex
-          }.distinct.length == rightSkewNum)
-        }
+          def checkSkewJoin(
+              joins: Seq[ShuffledJoin],
+              leftSkewNum: Int,
+              rightSkewNum: Int): Unit = {
+            assert(joins.size == 1 && joins.head.isSkewJoin)
+            assert(joins.head.left.collect {
+              case r: CustomShuffleReaderExec => r
+            }.head.partitionSpecs.collect {
+              case p: PartialReducerPartitionSpec => p.reducerIndex
+            }.distinct.length == leftSkewNum)
+            assert(joins.head.right.collect {
+              case r: CustomShuffleReaderExec => r
+            }.head.partitionSpecs.collect {
+              case p: PartialReducerPartitionSpec => p.reducerIndex
+            }.distinct.length == rightSkewNum)
+          }
 
-        // skewed inner join optimization
-        val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
-          "SELECT * FROM skewData1 join skewData2 ON key1 = key2")
-        val innerSmj = findTopLevelSortMergeJoin(innerAdaptivePlan)
-        checkSkewJoin(innerSmj, 2, 1)
-
-        // skewed left outer join optimization
-        val (_, leftAdaptivePlan) = runAdaptiveAndVerifyResult(
-          "SELECT * FROM skewData1 left outer join skewData2 ON key1 = key2")
-        val leftSmj = findTopLevelSortMergeJoin(leftAdaptivePlan)
-        checkSkewJoin(leftSmj, 2, 0)
-
-        // skewed right outer join optimization
-        val (_, rightAdaptivePlan) = runAdaptiveAndVerifyResult(
-          "SELECT * FROM skewData1 right outer join skewData2 ON key1 = key2")
-        val rightSmj = findTopLevelSortMergeJoin(rightAdaptivePlan)
-        checkSkewJoin(rightSmj, 0, 1)
+          // skewed inner join optimization
+          val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
+            s"SELECT /*+ $joinHint(skewData1) */ * FROM skewData1 " +
+              s"JOIN skewData2 ON key1 = key2")

Review comment:
       nit: remove the `s` interpolator prefix from the second string literal; it contains no `$` substitutions.

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
##########
@@ -157,98 +157,121 @@ object OptimizeSkewedJoin extends CustomShuffleReaderRule {
    * 4. Wrap the join right child with a special shuffle reader that reads partition0 3 times by
    *    3 tasks separately.
    */
-  def optimizeSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp {
-    case smj @ SortMergeJoinExec(_, _, joinType, _,
-        s1 @ SortExec(_, _, ShuffleStage(left: ShuffleStageInfo), _),
-        s2 @ SortExec(_, _, ShuffleStage(right: ShuffleStageInfo), _), _)
-        if supportedJoinTypes.contains(joinType) =>
-      assert(left.partitionsWithSizes.length == right.partitionsWithSizes.length)
-      val numPartitions = left.partitionsWithSizes.length
-      // We use the median size of the original shuffle partitions to detect skewed partitions.
-      val leftMedSize = medianSize(left.mapStats)
-      val rightMedSize = medianSize(right.mapStats)
-      logDebug(
-        s"""
-          |Optimizing skewed join.
-          |Left side partitions size info:
-          |${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId)}
-          |Right side partitions size info:
-          |${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId)}
-        """.stripMargin)
-      val canSplitLeft = canSplitLeftSide(joinType)
-      val canSplitRight = canSplitRightSide(joinType)
-      // We use the actual partition sizes (may be coalesced) to calculate target size, so that
-      // the final data distribution is even (coalesced partitions + split partitions).
-      val leftActualSizes = left.partitionsWithSizes.map(_._2)
-      val rightActualSizes = right.partitionsWithSizes.map(_._2)
-      val leftTargetSize = targetSize(leftActualSizes, leftMedSize)
-      val rightTargetSize = targetSize(rightActualSizes, rightMedSize)
-
-      val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
-      val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
-      var numSkewedLeft = 0
-      var numSkewedRight = 0
-      for (partitionIndex <- 0 until numPartitions) {
-        val leftActualSize = leftActualSizes(partitionIndex)
-        val isLeftSkew = isSkewed(leftActualSize, leftMedSize) && canSplitLeft
-        val leftPartSpec = left.partitionsWithSizes(partitionIndex)._1
-        val isLeftCoalesced = leftPartSpec.startReducerIndex + 1 < leftPartSpec.endReducerIndex
-
-        val rightActualSize = rightActualSizes(partitionIndex)
-        val isRightSkew = isSkewed(rightActualSize, rightMedSize) && canSplitRight
-        val rightPartSpec = right.partitionsWithSizes(partitionIndex)._1
-        val isRightCoalesced = rightPartSpec.startReducerIndex + 1 < rightPartSpec.endReducerIndex
-
-        // A skewed partition should never be coalesced, but skip it here just to be safe.
-        val leftParts = if (isLeftSkew && !isLeftCoalesced) {
-          val reducerId = leftPartSpec.startReducerIndex
-          val skewSpecs = createSkewPartitionSpecs(
-            left.mapStats.shuffleId, reducerId, leftTargetSize)
-          if (skewSpecs.isDefined) {
-            logDebug(s"Left side partition $partitionIndex " +
-              s"(${FileUtils.byteCountToDisplaySize(leftActualSize)}) is skewed, " +
-              s"split it into ${skewSpecs.get.length} parts.")
-            numSkewedLeft += 1
-          }
-          skewSpecs.getOrElse(Seq(leftPartSpec))
-        } else {
-          Seq(leftPartSpec)
+  private def getOptimizedChildren(
+      left: ShuffleStageInfo,
+      right: ShuffleStageInfo,
+      joinType: JoinType): Option[(SparkPlan, SparkPlan)] = {
+    assert(left.partitionsWithSizes.length == right.partitionsWithSizes.length)
+    val numPartitions = left.partitionsWithSizes.length
+    // We use the median size of the original shuffle partitions to detect skewed partitions.
+    val leftMedSize = medianSize(left.mapStats)
+    val rightMedSize = medianSize(right.mapStats)
+    logDebug(
+      s"""
+         |Optimizing skewed join.
+         |Left side partitions size info:
+         |${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId)}
+         |Right side partitions size info:
+         |${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId)}
+      """.stripMargin)
+    val canSplitLeft = canSplitLeftSide(joinType)
+    val canSplitRight = canSplitRightSide(joinType)
+    // We use the actual partition sizes (may be coalesced) to calculate target size, so that
+    // the final data distribution is even (coalesced partitions + split partitions).
+    val leftActualSizes = left.partitionsWithSizes.map(_._2)
+    val rightActualSizes = right.partitionsWithSizes.map(_._2)
+    val leftTargetSize = targetSize(leftActualSizes, leftMedSize)
+    val rightTargetSize = targetSize(rightActualSizes, rightMedSize)
+
+    val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
+    val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
+    var numSkewedLeft = 0
+    var numSkewedRight = 0
+    for (partitionIndex <- 0 until numPartitions) {
+      val leftActualSize = leftActualSizes(partitionIndex)
+      val isLeftSkew = isSkewed(leftActualSize, leftMedSize) && canSplitLeft
+      val leftPartSpec = left.partitionsWithSizes(partitionIndex)._1
+      val isLeftCoalesced = leftPartSpec.startReducerIndex + 1 < leftPartSpec.endReducerIndex
+
+      val rightActualSize = rightActualSizes(partitionIndex)
+      val isRightSkew = isSkewed(rightActualSize, rightMedSize) && canSplitRight
+      val rightPartSpec = right.partitionsWithSizes(partitionIndex)._1
+      val isRightCoalesced = rightPartSpec.startReducerIndex + 1 < rightPartSpec.endReducerIndex
+
+      // A skewed partition should never be coalesced, but skip it here just to be safe.
+      val leftParts = if (isLeftSkew && !isLeftCoalesced) {
+        val reducerId = leftPartSpec.startReducerIndex
+        val skewSpecs = createSkewPartitionSpecs(
+          left.mapStats.shuffleId, reducerId, leftTargetSize)
+        if (skewSpecs.isDefined) {
+          logDebug(s"Left side partition $partitionIndex " +
+            s"(${FileUtils.byteCountToDisplaySize(leftActualSize)}) is skewed, " +
+            s"split it into ${skewSpecs.get.length} parts.")
+          numSkewedLeft += 1
         }
+        skewSpecs.getOrElse(Seq(leftPartSpec))
+      } else {
+        Seq(leftPartSpec)
+      }
 
-        // A skewed partition should never be coalesced, but skip it here just to be safe.
-        val rightParts = if (isRightSkew && !isRightCoalesced) {
-          val reducerId = rightPartSpec.startReducerIndex
-          val skewSpecs = createSkewPartitionSpecs(
-            right.mapStats.shuffleId, reducerId, rightTargetSize)
-          if (skewSpecs.isDefined) {
-            logDebug(s"Right side partition $partitionIndex " +
-              s"(${FileUtils.byteCountToDisplaySize(rightActualSize)}) is skewed, " +
-              s"split it into ${skewSpecs.get.length} parts.")
-            numSkewedRight += 1
-          }
-          skewSpecs.getOrElse(Seq(rightPartSpec))
-        } else {
-          Seq(rightPartSpec)
+      // A skewed partition should never be coalesced, but skip it here just to be safe.
+      val rightParts = if (isRightSkew && !isRightCoalesced) {
+        val reducerId = rightPartSpec.startReducerIndex
+        val skewSpecs = createSkewPartitionSpecs(
+          right.mapStats.shuffleId, reducerId, rightTargetSize)
+        if (skewSpecs.isDefined) {
+          logDebug(s"Right side partition $partitionIndex " +
+            s"(${FileUtils.byteCountToDisplaySize(rightActualSize)}) is skewed, " +
+            s"split it into ${skewSpecs.get.length} parts.")
+          numSkewedRight += 1
         }
+        skewSpecs.getOrElse(Seq(rightPartSpec))
+      } else {
+        Seq(rightPartSpec)
+      }
 
-        for {
-          leftSidePartition <- leftParts
-          rightSidePartition <- rightParts
-        } {
-          leftSidePartitions += leftSidePartition
-          rightSidePartitions += rightSidePartition
-        }
+      for {
+        leftSidePartition <- leftParts
+        rightSidePartition <- rightParts
+      } {
+        leftSidePartitions += leftSidePartition
+        rightSidePartitions += rightSidePartition
       }
+    }
+    logDebug(s"number of skewed partitions: left $numSkewedLeft, right $numSkewedRight")
+    if (numSkewedLeft > 0 || numSkewedRight > 0) {
+      Some((CustomShuffleReaderExec(left.shuffleStage, leftSidePartitions.toSeq),
+        CustomShuffleReaderExec(right.shuffleStage, rightSidePartitions.toSeq)))
+    } else {
+      None
+    }
+  }
 

Review comment:
       nit: This looks like an unnecessary change.

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
##########
@@ -52,16 +51,6 @@ case class SortMergeJoinExec(
 
   override def stringArgs: Iterator[Any] = super.stringArgs.toSeq.dropRight(1).iterator
 
-  override def requiredChildDistribution: Seq[Distribution] = {

Review comment:
       Could you move `nodeName` into the `ShuffledJoin` trait, too, so that both shuffled join operators share it?
   ```
     override def nodeName: String = {
       if (isSkewJoin) super.nodeName + "(skew=true)" else super.nodeName
     }
   ```

##########
File path: sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
##########
@@ -685,64 +691,82 @@ class AdaptiveQueryExecSuite
   }
 
   test("SPARK-29544: adaptive skew join with different join types") {
-    withSQLConf(
-      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
-      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
-      SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
-      SQLConf.SHUFFLE_PARTITIONS.key -> "100",
-      SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800",
-      SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") {
-      withTempView("skewData1", "skewData2") {
-        spark
-          .range(0, 1000, 1, 10)
-          .select(
-            when('id < 250, 249)
-              .when('id >= 750, 1000)
-              .otherwise('id).as("key1"),
-            'id as "value1")
-          .createOrReplaceTempView("skewData1")
-        spark
-          .range(0, 1000, 1, 10)
-          .select(
-            when('id < 250, 249)
-              .otherwise('id).as("key2"),
-            'id as "value2")
-          .createOrReplaceTempView("skewData2")
+    Seq("SHUFFLE_MERGE", "SHUFFLE_HASH").foreach { joinHint =>
+      val isSMJ = joinHint == "SHUFFLE_MERGE"
+      withSQLConf(
+        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+        SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
+        SQLConf.SHUFFLE_PARTITIONS.key -> "100",
+        SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800",
+        SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") {
+        withTempView("skewData1", "skewData2") {
+          spark
+            .range(0, 1000, 1, 10)
+            .select(
+              when('id < 250, 249)
+                .when('id >= 750, 1000)
+                .otherwise('id).as("key1"),
+              'id as "value1")
+            .createOrReplaceTempView("skewData1")
+          spark
+            .range(0, 1000, 1, 10)
+            .select(
+              when('id < 250, 249)
+                .otherwise('id).as("key2"),
+              'id as "value2")
+            .createOrReplaceTempView("skewData2")
 
-        def checkSkewJoin(
-            joins: Seq[SortMergeJoinExec],
-            leftSkewNum: Int,
-            rightSkewNum: Int): Unit = {
-          assert(joins.size == 1 && joins.head.isSkewJoin)
-          assert(joins.head.left.collect {
-            case r: CustomShuffleReaderExec => r
-          }.head.partitionSpecs.collect {
-            case p: PartialReducerPartitionSpec => p.reducerIndex
-          }.distinct.length == leftSkewNum)
-          assert(joins.head.right.collect {
-            case r: CustomShuffleReaderExec => r
-          }.head.partitionSpecs.collect {
-            case p: PartialReducerPartitionSpec => p.reducerIndex
-          }.distinct.length == rightSkewNum)
-        }
+          def checkSkewJoin(
+              joins: Seq[ShuffledJoin],
+              leftSkewNum: Int,
+              rightSkewNum: Int): Unit = {
+            assert(joins.size == 1 && joins.head.isSkewJoin)
+            assert(joins.head.left.collect {
+              case r: CustomShuffleReaderExec => r
+            }.head.partitionSpecs.collect {
+              case p: PartialReducerPartitionSpec => p.reducerIndex
+            }.distinct.length == leftSkewNum)
+            assert(joins.head.right.collect {
+              case r: CustomShuffleReaderExec => r
+            }.head.partitionSpecs.collect {
+              case p: PartialReducerPartitionSpec => p.reducerIndex
+            }.distinct.length == rightSkewNum)
+          }
 
-        // skewed inner join optimization
-        val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
-          "SELECT * FROM skewData1 join skewData2 ON key1 = key2")
-        val innerSmj = findTopLevelSortMergeJoin(innerAdaptivePlan)
-        checkSkewJoin(innerSmj, 2, 1)
-
-        // skewed left outer join optimization
-        val (_, leftAdaptivePlan) = runAdaptiveAndVerifyResult(
-          "SELECT * FROM skewData1 left outer join skewData2 ON key1 = key2")
-        val leftSmj = findTopLevelSortMergeJoin(leftAdaptivePlan)
-        checkSkewJoin(leftSmj, 2, 0)
-
-        // skewed right outer join optimization
-        val (_, rightAdaptivePlan) = runAdaptiveAndVerifyResult(
-          "SELECT * FROM skewData1 right outer join skewData2 ON key1 = key2")
-        val rightSmj = findTopLevelSortMergeJoin(rightAdaptivePlan)
-        checkSkewJoin(rightSmj, 0, 1)
+          // skewed inner join optimization
+          val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
+            s"SELECT /*+ $joinHint(skewData1) */ * FROM skewData1 " +
+              s"JOIN skewData2 ON key1 = key2")
+          val inner = if (isSMJ) {
+            findTopLevelSortMergeJoin(innerAdaptivePlan)
+          } else {
+            findTopLevelShuffledHashJoin(innerAdaptivePlan)
+          }
+          checkSkewJoin(inner, 2, 1)
+
+          // skewed left outer join optimization
+          val (_, leftAdaptivePlan) = runAdaptiveAndVerifyResult(
+            s"SELECT /*+ $joinHint(skewData2) */ * FROM skewData1 " +
+              s"LEFT OUTER JOIN skewData2 ON key1 = key2")

Review comment:
       ditto: the second string literal has no interpolated values, so its `s` prefix is unnecessary.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org