Posted to commits@spark.apache.org by we...@apache.org on 2021/11/15 11:35:39 UTC

[spark] branch master updated: [SPARK-35352][SQL] Add code-gen for full outer sort merge join

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2ef60f72 [SPARK-35352][SQL] Add code-gen for full outer sort merge join
2ef60f72 is described below

commit 2ef60f726c79349cbcda6f34f3e99b32951388bf
Author: Cheng Su <ch...@fb.com>
AuthorDate: Mon Nov 15 19:34:52 2021 +0800

    [SPARK-35352][SQL] Add code-gen for full outer sort merge join
    
    ### What changes were proposed in this pull request?
    
    This PR adds code-gen for FULL OUTER sort merge join. The change is in `SortMergeJoinExec.scala:codegenFullOuter()`. It follows the same algorithm as iterator mode (`SortMergeFullOuterJoinScanner`): maintain a buffer for each of the left and right join sides, and iterate over the matched rows in the buffers.
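    
    For illustration, here is a minimal plain-Scala sketch of that buffering algorithm. It is not the generated Java code. It assumes both inputs are already sorted by a single ascending `Int` key, with `None` standing in for a null key. The names `FullOuterSmjSketch`, `fullOuterSmj`, `Row` and `cond` are hypothetical. As in the generated code, rows with a null or smaller key are emitted padded with nulls, rows sharing the current key are buffered on both sides, and a bit set ensures buffered right rows that fail the join condition are still emitted once.
    
    ```scala
    import scala.collection.mutable
    
    object FullOuterSmjSketch {
      // A row is (join key, payload); None plays the role of a SQL null key.
      type Row = (Option[Int], String)
    
      // Full outer sort merge join of two inputs sorted by ascending key.
      // `cond` is the extra (non-equi) join condition, checked on each key-matched pair.
      // A None on either output side stands for the null-padded row.
      def fullOuterSmj(
          left: IndexedSeq[Row],
          right: IndexedSeq[Row],
          cond: (Row, Row) => Boolean): Seq[(Option[Row], Option[Row])] = {
        val out = mutable.ArrayBuffer.empty[(Option[Row], Option[Row])]
        var i = 0
        var j = 0
        while (i < left.length && j < right.length) {
          val l = left(i)
          val r = right(j)
          if (l._1.isEmpty) {
            // A null key never matches: pad with a null right row.
            out += ((Some(l), None)); i += 1
          } else if (r._1.isEmpty) {
            out += ((None, Some(r))); j += 1
          } else if (l._1.get < r._1.get) {
            // The smaller key cannot match any later row on the other side.
            out += ((Some(l), None)); i += 1
          } else if (l._1.get > r._1.get) {
            out += ((None, Some(r))); j += 1
          } else {
            // Equal keys: buffer every row with this key from both sides.
            val key = l._1
            val leftBuf = left.slice(i, left.length).takeWhile(_._1 == key)
            val rightBuf = right.slice(j, right.length).takeWhile(_._1 == key)
            i += leftBuf.length
            j += rightBuf.length
            val rightMatched = mutable.BitSet.empty
            for (lr <- leftBuf) {
              var leftMatched = false // a local flag replaces the left-side bit set
              for ((rr, idx) <- rightBuf.zipWithIndex if cond(lr, rr)) {
                out += ((Some(lr), Some(rr)))
                leftMatched = true
                rightMatched += idx
              }
              if (!leftMatched) out += ((Some(lr), None))
            }
            // Buffered right rows that never satisfied `cond` are padded with nulls.
            for ((rr, idx) <- rightBuf.zipWithIndex if !rightMatched(idx)) {
              out += ((None, Some(rr)))
            }
          }
        }
        // One side is exhausted: drain the other side, padding with null rows.
        while (i < left.length) { out += ((Some(left(i)), None)); i += 1 }
        while (j < right.length) { out += ((None, Some(right(j)))); j += 1 }
        out.toSeq
      }
    }
    ```
    
    The generated code keeps one `BitSet` per side and reuses it across key groups, clearing it when the buffer fits and reallocating otherwise. For brevity, the sketch allocates a fresh bit set for each group.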
    
    Example query:
    
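    ```scala
    import spark.implicits._  // for the $"col" syntax
    
    // `hint` is a join strategy hint, e.g. "SHUFFLE_MERGE" to request a sort merge join
    val hint = "SHUFFLE_MERGE"
    val df1 = spark.range(5).select($"id".as("k1"))
    val df2 = spark.range(10).select($"id".as("k2"))
    df1.join(df2.hint(hint), $"k1" === $"k2" % 3 && $"k1" + 3 =!= $"k2", "full_outer")
    ```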
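    
    As a reference for what the example query should return, the following plain-Scala sketch evaluates the same predicates with a naive nested loop and does not use Spark. It treats `k1 = k2 % 3` as the equi-join key and `k1 + 3 != k2` as the extra condition. `ExampleQueryReference` and `expectedFullOuter` are hypothetical names.
    
    ```scala
    object ExampleQueryReference {
      type Pair = (Option[Long], Option[Long])
    
      // Join predicate of the example query: equi-key k1 = k2 % 3 plus condition k1 + 3 != k2.
      def matches(k1: Long, k2: Long): Boolean = k1 == k2 % 3 && k1 + 3 != k2
    
      // Naive full outer join: all matching pairs, then unmatched rows padded with None (null).
      def expectedFullOuter(ks1: Seq[Long], ks2: Seq[Long]): Seq[Pair] = {
        val matched: Seq[Pair] =
          for (k1 <- ks1; k2 <- ks2 if matches(k1, k2)) yield (Option(k1), Option(k2))
        val leftOnly: Seq[Pair] =
          ks1.filterNot(k1 => ks2.exists(matches(k1, _))).map(k1 => (Option(k1), None))
        val rightOnly: Seq[Pair] =
          ks2.filterNot(k2 => ks1.exists(matches(_, k2))).map(k2 => (None, Option(k2)))
        matched ++ leftOnly ++ rightOnly
      }
    }
    ```
    
    `expectedFullOuter(0L until 5L, 0L until 10L)` yields 12 rows: 7 matched pairs, left-only keys 3 and 4, and right-only keys 3, 4 and 5. The right rows 3, 4 and 5 share a join key with a left row (0, 1 and 2 respectively) but fail `k1 + 3 != k2`. This is why the join must track matches per buffered row within each key group, not only compare keys.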
    
    Example generated code: https://gist.github.com/c21/5cab9751f24ae448d77a259d28cb77d7
    
    In addition, to help review: this PR changes several TPCDS plan files, but only the files below contain real code changes:
    
    * `SortMergeJoinExec.scala`
    * `WholeStageCodegenSuite.scala`
    
    All other files are auto-generated golden file plan changes for TPCDS queries.
    
    ### Why are the changes needed?
    
    Improve the run-time/CPU performance of FULL OUTER sort merge join.
    
    Micro benchmark (same query in `JoinBenchmark.scala`):
    
    ```
      def sortMergeJoin(): Unit = {
        val N = 2 << 20
        codegenBenchmark("sort merge join", N) {
          val df1 = spark.range(N).selectExpr(s"id * 2 as k1")
          val df2 = spark.range(N).selectExpr(s"id * 3 as k2")
          val df = df1.join(df2, col("k1") === col("k2"), "full_outer")
          assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined)
          df.noop()
        }
      }
    
      def sortMergeJoinWithDuplicates(): Unit = {
        val N = 2 << 20
        codegenBenchmark("sort merge join with duplicates", N) {
          val df1 = spark.range(N)
            .selectExpr(s"(id * 15485863) % ${N*10} as k1")
          val df2 = spark.range(N)
            .selectExpr(s"(id * 15485867) % ${N*10} as k2")
          val df = df1.join(df2, col("k1") === col("k2"), "full_outer")
          assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined)
          df.noop()
        }
      }
    ```
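    
    As a rough sanity check on the output size of the first benchmark: `k1 = id * 2` and `k2 = id * 3` are unique on each side, so the full outer join emits one row per distinct key in the union of both sides. This is `2 * N` minus the number of shared keys, which are the multiples of 6 up to `2 * (N - 1)`. The sketch below computes this in closed form and cross-checks it by brute force for small `N`. The object and method names are hypothetical.
    
    ```scala
    object BenchmarkOutputSize {
      // Keys present on both sides: multiples of 6 in [0, 2 * (n - 1)].
      def matchedKeys(n: Long): Long = if (n <= 0) 0L else (n - 1) / 3 + 1
    
      // Keys are unique per side, so each shared key merges one left and one right row.
      def fullOuterRows(n: Long): Long = 2 * n - matchedKeys(n)
    
      // Brute-force cross-check: size of the union of both key sets.
      def bruteForceRows(n: Int): Long = {
        val left = (0 until n).map(_ * 2L).toSet
        val right = (0 until n).map(_ * 3L).toSet
        (left | right).size.toLong
      }
    }
    ```
    
    For `N = 2 << 20`, this gives 3,495,253 output rows, with 699,051 matched keys.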
    
    This shows a 20-30% run-time improvement (1.2X-1.3X in the Relative column):
    
    ```
    Running benchmark: sort merge join
      Running case: sort merge join wholestage off
      Stopped after 2 iterations, 2979 ms
      Running case: sort merge join wholestage on
      Stopped after 5 iterations, 5849 ms
    
    Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.16
    Intel(R) Core(TM) i9-9980HK CPU  2.40GHz
    sort merge join:                          Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
    ------------------------------------------------------------------------------------------------------------------------
    sort merge join wholestage off                     1453           1490          52          1.4         693.0       1.0X
    sort merge join wholestage on                      1115           1170          43          1.9         531.6       1.3X
    
    Running benchmark: sort merge join with duplicates
      Running case: sort merge join with duplicates wholestage off
      Stopped after 2 iterations, 3236 ms
      Running case: sort merge join with duplicates wholestage on
      Stopped after 5 iterations, 6768 ms
    
    Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.16
    Intel(R) Core(TM) i9-9980HK CPU  2.40GHz
    sort merge join with duplicates:                Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
    ------------------------------------------------------------------------------------------------------------------------------
    sort merge join with duplicates wholestage off           1609           1618          13          1.3         767.2       1.0X
    sort merge join with duplicates wholestage on            1330           1354          24          1.6         634.4       1.2X
    ```
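    
    The derived columns can be reproduced from the best times. The sketch below is consistent with the tables and assumes the benchmark utility computes Rate and Per Row from the best time over `N = 2 << 20` rows per iteration. It also assumes Relative is the baseline's best time divided by each case's best time. The object and method names are hypothetical.
    
    ```scala
    object BenchmarkColumns {
      // N in the benchmark: rows processed per iteration.
      val rowsPerIteration: Long = 2L << 20
    
      // Rate(M/s): millions of rows per second, from the best time in milliseconds.
      def rateMps(bestMs: Double): Double = rowsPerIteration / bestMs / 1000.0
    
      // Per Row(ns): nanoseconds per row, from the best time in milliseconds.
      def perRowNs(bestMs: Double): Double = bestMs * 1e6 / rowsPerIteration
    
      // Relative: baseline (wholestage off) best time divided by this case's best time.
      def relative(baselineBestMs: Double, bestMs: Double): Double = baselineBestMs / bestMs
    }
    ```
    
    With the best times from the tables, `relative(1453, 1115)` is about 1.30 and `relative(1609, 1330)` is about 1.21, matching the 1.3X and 1.2X entries. Small differences in the other columns come from the best times being printed rounded to whole milliseconds.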
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    * Added a unit test in `WholeStageCodegenSuite.scala`.
    * Existing unit tests in `OuterJoinSuite.scala`.
    
    Closes #34581 from c21/smj-codegen.
    
    Authored-by: Cheng Su <ch...@fb.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../sql/execution/joins/SortMergeJoinExec.scala    | 256 ++++++++++++++++++++-
 .../approved-plans-v1_4/q51.sf100/explain.txt      |   4 +-
 .../approved-plans-v1_4/q51.sf100/simplified.txt   |   5 +-
 .../approved-plans-v1_4/q51/explain.txt            |   4 +-
 .../approved-plans-v1_4/q51/simplified.txt         |   5 +-
 .../approved-plans-v1_4/q97.sf100/explain.txt      |   4 +-
 .../approved-plans-v1_4/q97.sf100/simplified.txt   |   5 +-
 .../approved-plans-v1_4/q97/explain.txt            |   4 +-
 .../approved-plans-v1_4/q97/simplified.txt         |   5 +-
 .../approved-plans-v2_7/q51a.sf100/explain.txt     |   4 +-
 .../approved-plans-v2_7/q51a.sf100/simplified.txt  |   5 +-
 .../approved-plans-v2_7/q51a/explain.txt           |   4 +-
 .../approved-plans-v2_7/q51a/simplified.txt        |   5 +-
 .../sql/execution/WholeStageCodegenSuite.scala     |  82 ++++---
 14 files changed, 327 insertions(+), 65 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 66054bf..afed14a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -364,7 +364,8 @@ case class SortMergeJoinExec(
   }
 
   private lazy val ((streamedPlan, streamedKeys), (bufferedPlan, bufferedKeys)) = joinType match {
-    case _: InnerLike | LeftOuter | LeftSemi | LeftAnti => ((left, leftKeys), (right, rightKeys))
+    case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | FullOuter =>
+      ((left, leftKeys), (right, rightKeys))
     case RightOuter => ((right, rightKeys), (left, leftKeys))
     case x =>
       throw new IllegalArgumentException(
@@ -374,9 +375,10 @@ case class SortMergeJoinExec(
   private lazy val streamedOutput = streamedPlan.output
   private lazy val bufferedOutput = bufferedPlan.output
 
+  // TODO(SPARK-37316): Add code-gen for existence sort merge join.
   override def supportCodegen: Boolean = joinType match {
-    case _: InnerLike | LeftOuter | RightOuter | LeftSemi | LeftAnti => true
-    case _ => false
+    case _: ExistenceJoin => false
+    case _ => true
   }
 
   override def inputRDDs(): Seq[RDD[InternalRow]] = {
@@ -644,6 +646,12 @@ case class SortMergeJoinExec(
   override def needCopyResult: Boolean = true
 
   override def doProduce(ctx: CodegenContext): String = {
+    // Specialize `doProduce` code for full outer join, because full outer join needs to
+    // buffer both sides of join.
+    if (joinType == FullOuter) {
+      return codegenFullOuter(ctx)
+    }
+
     // Inline mutable state since not many join operations in a task
     val streamedInput = ctx.addMutableState("scala.collection.Iterator", "streamedInput",
       v => s"$v = inputs[0];", forceInline = true)
@@ -890,6 +898,248 @@ case class SortMergeJoinExec(
      """.stripMargin
   }
 
+  /**
+   * Generates the code for Full Outer join.
+   */
+  private def codegenFullOuter(ctx: CodegenContext): String = {
+    // Inline mutable state since not many join operations in a task.
+    // Create class member for input iterator from both sides.
+    val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput",
+      v => s"$v = inputs[0];", forceInline = true)
+    val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput",
+      v => s"$v = inputs[1];", forceInline = true)
+
+    // Create class member for next input row from both sides.
+    val leftInputRow = ctx.addMutableState("InternalRow", "leftInputRow", forceInline = true)
+    val rightInputRow = ctx.addMutableState("InternalRow", "rightInputRow", forceInline = true)
+
+    // Create variables for join keys from both sides.
+    val leftKeyVars = createJoinKey(ctx, leftInputRow, leftKeys, left.output)
+    val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ")
+    val rightKeyVars = createJoinKey(ctx, rightInputRow, rightKeys, right.output)
+    val rightAnyNull = rightKeyVars.map(_.isNull).mkString(" || ")
+    val matchedKeyVars = copyKeys(ctx, leftKeyVars)
+    val leftMatchedKeyVars = createJoinKey(ctx, leftInputRow, leftKeys, left.output)
+    val rightMatchedKeyVars = createJoinKey(ctx, rightInputRow, rightKeys, right.output)
+
+    // Create class member for next output row from both sides.
+    val leftOutputRow = ctx.addMutableState("InternalRow", "leftOutputRow", forceInline = true)
+    val rightOutputRow = ctx.addMutableState("InternalRow", "rightOutputRow", forceInline = true)
+
+    // Create class member for buffers of rows with same join keys from both sides.
+    val bufferClsName = "java.util.ArrayList<InternalRow>"
+    val leftBuffer = ctx.addMutableState(bufferClsName, "leftBuffer",
+      v => s"$v = new $bufferClsName();", forceInline = true)
+    val rightBuffer = ctx.addMutableState(bufferClsName, "rightBuffer",
+      v => s"$v = new $bufferClsName();", forceInline = true)
+    val matchedClsName = classOf[BitSet].getName
+    val leftMatched = ctx.addMutableState(matchedClsName, "leftMatched",
+      v => s"$v = new $matchedClsName(1);", forceInline = true)
+    val rightMatched = ctx.addMutableState(matchedClsName, "rightMatched",
+      v => s"$v = new $matchedClsName(1);", forceInline = true)
+    val leftIndex = ctx.freshName("leftIndex")
+    val rightIndex = ctx.freshName("rightIndex")
+
+    // Generate code for join condition
+    val leftResultVars = genOneSideJoinVars(
+      ctx, leftOutputRow, left, setDefaultValue = true)
+    val rightResultVars = genOneSideJoinVars(
+      ctx, rightOutputRow, right, setDefaultValue = true)
+    val resultVars = leftResultVars ++ rightResultVars
+    val (_, conditionCheck, _) =
+      getJoinCondition(ctx, leftResultVars, left, right, Some(rightOutputRow))
+
+    // Generate code for result output in separate function, as we need to output result from
+    // multiple places in join code.
+    val consumeFullOuterJoinRow = ctx.freshName("consumeFullOuterJoinRow")
+    ctx.addNewFunction(consumeFullOuterJoinRow,
+      s"""
+         |private void $consumeFullOuterJoinRow() throws java.io.IOException {
+         |  ${metricTerm(ctx, "numOutputRows")}.add(1);
+         |  ${consume(ctx, resultVars)}
+         |}
+       """.stripMargin)
+
+    // Handle the case when input row has no match.
+    val outputLeftNoMatch =
+      s"""
+         |$leftOutputRow = $leftInputRow;
+         |$rightOutputRow = null;
+         |$leftInputRow = null;
+         |$consumeFullOuterJoinRow();
+       """.stripMargin
+    val outputRightNoMatch =
+      s"""
+         |$rightOutputRow = $rightInputRow;
+         |$leftOutputRow = null;
+         |$rightInputRow = null;
+         |$consumeFullOuterJoinRow();
+       """.stripMargin
+
+    // Generate a function to scan both sides to find rows with matched join keys.
+    // The matched rows from both sides are copied in buffers separately. This function assumes
+    // either non-empty `leftIter` and `rightIter`, or non-null `leftInputRow` and `rightInputRow`.
+    //
+    // The function has the following steps:
+    //  - Step 1: Find the next `leftInputRow` and `rightInputRow` with non-null join keys.
+    //            Output row with null join keys (`outputLeftNoMatch` and `outputRightNoMatch`).
+    //
+    //  - Step 2: Compare and find next same join keys from between `leftInputRow` and
+    //            `rightInputRow`.
+    //            Output row with smaller join keys (`outputLeftNoMatch` and `outputRightNoMatch`).
+    //
+    //  - Step 3: Buffer rows with same join keys from both sides into `leftBuffer` and
+    //            `rightBuffer`. Reset bit sets for both buffers accordingly (`leftMatched` and
+    //            `rightMatched`).
+    val findNextJoinRowsFuncName = ctx.freshName("findNextJoinRows")
+    ctx.addNewFunction(findNextJoinRowsFuncName,
+      s"""
+         |private void $findNextJoinRowsFuncName(
+         |    scala.collection.Iterator leftIter,
+         |    scala.collection.Iterator rightIter) throws java.io.IOException {
+         |  int comp = 0;
+         |  $leftBuffer.clear();
+         |  $rightBuffer.clear();
+         |
+         |  if ($leftInputRow == null) {
+         |    $leftInputRow = (InternalRow) leftIter.next();
+         |  }
+         |  if ($rightInputRow == null) {
+         |    $rightInputRow = (InternalRow) rightIter.next();
+         |  }
+         |
+         |  ${leftKeyVars.map(_.code).mkString("\n")}
+         |  if ($leftAnyNull) {
+         |    // The left row join key is null, join it with null row
+         |    $outputLeftNoMatch
+         |    return;
+         |  }
+         |
+         |  ${rightKeyVars.map(_.code).mkString("\n")}
+         |  if ($rightAnyNull) {
+         |    // The right row join key is null, join it with null row
+         |    $outputRightNoMatch
+         |    return;
+         |  }
+         |
+         |  ${genComparison(ctx, leftKeyVars, rightKeyVars)}
+         |  if (comp < 0) {
+         |    // The left row join key is smaller, join it with null row
+         |    $outputLeftNoMatch
+         |    return;
+         |  } else if (comp > 0) {
+         |    // The right row join key is smaller, join it with null row
+         |    $outputRightNoMatch
+         |    return;
+         |  }
+         |
+         |  ${matchedKeyVars.map(_.code).mkString("\n")}
+         |  $leftBuffer.add($leftInputRow.copy());
+         |  $rightBuffer.add($rightInputRow.copy());
+         |  $leftInputRow = null;
+         |  $rightInputRow = null;
+         |
+         |  // Buffer rows from both sides with same join key
+         |  while (leftIter.hasNext()) {
+         |    $leftInputRow = (InternalRow) leftIter.next();
+         |    ${leftMatchedKeyVars.map(_.code).mkString("\n")}
+         |    ${genComparison(ctx, leftMatchedKeyVars, matchedKeyVars)}
+         |    if (comp == 0) {
+         |
+         |      $leftBuffer.add($leftInputRow.copy());
+         |      $leftInputRow = null;
+         |    } else {
+         |      break;
+         |    }
+         |  }
+         |  while (rightIter.hasNext()) {
+         |    $rightInputRow = (InternalRow) rightIter.next();
+         |    ${rightMatchedKeyVars.map(_.code).mkString("\n")}
+         |    ${genComparison(ctx, rightMatchedKeyVars, matchedKeyVars)}
+         |    if (comp == 0) {
+         |      $rightBuffer.add($rightInputRow.copy());
+         |      $rightInputRow = null;
+         |    } else {
+         |      break;
+         |    }
+         |  }
+         |
+         |  // Reset bit sets of buffers accordingly
+         |  if ($leftBuffer.size() <= $leftMatched.capacity()) {
+         |    $leftMatched.clearUntil($leftBuffer.size());
+         |  } else {
+         |    $leftMatched = new $matchedClsName($leftBuffer.size());
+         |  }
+         |  if ($rightBuffer.size() <= $rightMatched.capacity()) {
+         |    $rightMatched.clearUntil($rightBuffer.size());
+         |  } else {
+         |    $rightMatched = new $matchedClsName($rightBuffer.size());
+         |  }
+         |}
+       """.stripMargin)
+
+    // Scan the left and right buffers to find all matched rows.
+    val matchRowsInBuffer =
+      s"""
+         |int $leftIndex;
+         |int $rightIndex;
+         |
+         |for ($leftIndex = 0; $leftIndex < $leftBuffer.size(); $leftIndex++) {
+         |  $leftOutputRow = (InternalRow) $leftBuffer.get($leftIndex);
+         |  for ($rightIndex = 0; $rightIndex < $rightBuffer.size(); $rightIndex++) {
+         |    $rightOutputRow = (InternalRow) $rightBuffer.get($rightIndex);
+         |    $conditionCheck {
+         |      $consumeFullOuterJoinRow();
+         |      $leftMatched.set($leftIndex);
+         |      $rightMatched.set($rightIndex);
+         |    }
+         |  }
+         |
+         |  if (!$leftMatched.get($leftIndex)) {
+         |
+         |    $rightOutputRow = null;
+         |    $consumeFullOuterJoinRow();
+         |  }
+         |}
+         |
+         |$leftOutputRow = null;
+         |for ($rightIndex = 0; $rightIndex < $rightBuffer.size(); $rightIndex++) {
+         |  if (!$rightMatched.get($rightIndex)) {
+         |    // The right row has never matched any left row, join it with null row
+         |    $rightOutputRow = (InternalRow) $rightBuffer.get($rightIndex);
+         |    $consumeFullOuterJoinRow();
+         |  }
+         |}
+       """.stripMargin
+
+    s"""
+       |while (($leftInputRow != null || $leftInput.hasNext()) &&
+       |  ($rightInputRow != null || $rightInput.hasNext())) {
+       |  $findNextJoinRowsFuncName($leftInput, $rightInput);
+       |  $matchRowsInBuffer
+       |  if (shouldStop()) return;
+       |}
+       |
+       |// The right iterator has no more rows, join left row with null
+       |while ($leftInputRow != null || $leftInput.hasNext()) {
+       |  if ($leftInputRow == null) {
+       |    $leftInputRow = (InternalRow) $leftInput.next();
+       |  }
+       |  $outputLeftNoMatch
+       |  if (shouldStop()) return;
+       |}
+       |
+       |// The left iterator has no more rows, join right row with null
+       |while ($rightInputRow != null || $rightInput.hasNext()) {
+       |  if ($rightInputRow == null) {
+       |    $rightInputRow = (InternalRow) $rightInput.next();
+       |  }
+       |  $outputRightNoMatch
+       |  if (shouldStop()) return;
+       |}
+     """.stripMargin
+  }
+
   override protected def withNewChildrenInternal(
       newLeft: SparkPlan, newRight: SparkPlan): SortMergeJoinExec =
     copy(left = newLeft, right = newRight)
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51.sf100/explain.txt
index 51b1ae5..cbb189e 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51.sf100/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51.sf100/explain.txt
@@ -5,7 +5,7 @@ TakeOrderedAndProject (37)
       +- * Sort (34)
          +- Exchange (33)
             +- * Project (32)
-               +- SortMergeJoin FullOuter (31)
+               +- * SortMergeJoin FullOuter (31)
                   :- * Sort (15)
                   :  +- Exchange (14)
                   :     +- * Project (13)
@@ -176,7 +176,7 @@ Arguments: hashpartitioning(item_sk#25, d_date#20, 5), ENSURE_REQUIREMENTS, [id=
 Input [3]: [item_sk#25, d_date#20, cume_sales#28]
 Arguments: [item_sk#25 ASC NULLS FIRST, d_date#20 ASC NULLS FIRST], false, 0
 
-(31) SortMergeJoin
+(31) SortMergeJoin [codegen id : 13]
 Left keys [2]: [item_sk#11, d_date#6]
 Right keys [2]: [item_sk#25, d_date#20]
 Join condition: None
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51.sf100/simplified.txt
index 38d3f50..489aab1 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51.sf100/simplified.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51.sf100/simplified.txt
@@ -9,8 +9,8 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store
                 Exchange [item_sk] #1
                   WholeStageCodegen (13)
                     Project [item_sk,item_sk,d_date,d_date,cume_sales,cume_sales]
-                      InputAdapter
-                        SortMergeJoin [item_sk,d_date,item_sk,d_date]
+                      SortMergeJoin [item_sk,d_date,item_sk,d_date]
+                        InputAdapter
                           WholeStageCodegen (6)
                             Sort [item_sk,d_date]
                               InputAdapter
@@ -45,6 +45,7 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store
                                                                                         Scan parquet default.date_dim [d_date_sk,d_date,d_month_seq]
                                                                   InputAdapter
                                                                     ReusedExchange [d_date_sk,d_date] #5
+                        InputAdapter
                           WholeStageCodegen (12)
                             Sort [item_sk,d_date]
                               InputAdapter
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51/explain.txt
index 51b1ae5..cbb189e 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51/explain.txt
@@ -5,7 +5,7 @@ TakeOrderedAndProject (37)
       +- * Sort (34)
          +- Exchange (33)
             +- * Project (32)
-               +- SortMergeJoin FullOuter (31)
+               +- * SortMergeJoin FullOuter (31)
                   :- * Sort (15)
                   :  +- Exchange (14)
                   :     +- * Project (13)
@@ -176,7 +176,7 @@ Arguments: hashpartitioning(item_sk#25, d_date#20, 5), ENSURE_REQUIREMENTS, [id=
 Input [3]: [item_sk#25, d_date#20, cume_sales#28]
 Arguments: [item_sk#25 ASC NULLS FIRST, d_date#20 ASC NULLS FIRST], false, 0
 
-(31) SortMergeJoin
+(31) SortMergeJoin [codegen id : 13]
 Left keys [2]: [item_sk#11, d_date#6]
 Right keys [2]: [item_sk#25, d_date#20]
 Join condition: None
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51/simplified.txt
index 38d3f50..489aab1 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51/simplified.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51/simplified.txt
@@ -9,8 +9,8 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store
                 Exchange [item_sk] #1
                   WholeStageCodegen (13)
                     Project [item_sk,item_sk,d_date,d_date,cume_sales,cume_sales]
-                      InputAdapter
-                        SortMergeJoin [item_sk,d_date,item_sk,d_date]
+                      SortMergeJoin [item_sk,d_date,item_sk,d_date]
+                        InputAdapter
                           WholeStageCodegen (6)
                             Sort [item_sk,d_date]
                               InputAdapter
@@ -45,6 +45,7 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store
                                                                                         Scan parquet default.date_dim [d_date_sk,d_date,d_month_seq]
                                                                   InputAdapter
                                                                     ReusedExchange [d_date_sk,d_date] #5
+                        InputAdapter
                           WholeStageCodegen (12)
                             Sort [item_sk,d_date]
                               InputAdapter
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97.sf100/explain.txt
index e9e97e9..e47aaf2 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97.sf100/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97.sf100/explain.txt
@@ -3,7 +3,7 @@
 +- Exchange (22)
    +- * HashAggregate (21)
       +- * Project (20)
-         +- SortMergeJoin FullOuter (19)
+         +- * SortMergeJoin FullOuter (19)
             :- * Sort (9)
             :  +- * HashAggregate (8)
             :     +- Exchange (7)
@@ -112,7 +112,7 @@ Results [2]: [cs_bill_customer_sk#9 AS customer_sk#14, cs_item_sk#10 AS item_sk#
 Input [2]: [customer_sk#14, item_sk#15]
 Arguments: [customer_sk#14 ASC NULLS FIRST, item_sk#15 ASC NULLS FIRST], false, 0
 
-(19) SortMergeJoin
+(19) SortMergeJoin [codegen id : 7]
 Left keys [2]: [customer_sk#7, item_sk#8]
 Right keys [2]: [customer_sk#14, item_sk#15]
 Join condition: None
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97.sf100/simplified.txt
index 227b3c6..99c8a1d 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97.sf100/simplified.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97.sf100/simplified.txt
@@ -5,8 +5,8 @@ WholeStageCodegen (8)
         WholeStageCodegen (7)
           HashAggregate [customer_sk,customer_sk] [sum,sum,sum,sum,sum,sum]
             Project [customer_sk,customer_sk]
-              InputAdapter
-                SortMergeJoin [customer_sk,item_sk,customer_sk,item_sk]
+              SortMergeJoin [customer_sk,item_sk,customer_sk,item_sk]
+                InputAdapter
                   WholeStageCodegen (3)
                     Sort [customer_sk,item_sk]
                       HashAggregate [ss_customer_sk,ss_item_sk] [customer_sk,item_sk]
@@ -29,6 +29,7 @@ WholeStageCodegen (8)
                                                         Scan parquet default.date_dim [d_date_sk,d_month_seq]
                                     InputAdapter
                                       ReusedExchange [d_date_sk] #3
+                InputAdapter
                   WholeStageCodegen (6)
                     Sort [customer_sk,item_sk]
                       HashAggregate [cs_bill_customer_sk,cs_item_sk] [customer_sk,item_sk]
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97/explain.txt
index e9e97e9..e47aaf2 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97/explain.txt
@@ -3,7 +3,7 @@
 +- Exchange (22)
    +- * HashAggregate (21)
       +- * Project (20)
-         +- SortMergeJoin FullOuter (19)
+         +- * SortMergeJoin FullOuter (19)
             :- * Sort (9)
             :  +- * HashAggregate (8)
             :     +- Exchange (7)
@@ -112,7 +112,7 @@ Results [2]: [cs_bill_customer_sk#9 AS customer_sk#14, cs_item_sk#10 AS item_sk#
 Input [2]: [customer_sk#14, item_sk#15]
 Arguments: [customer_sk#14 ASC NULLS FIRST, item_sk#15 ASC NULLS FIRST], false, 0
 
-(19) SortMergeJoin
+(19) SortMergeJoin [codegen id : 7]
 Left keys [2]: [customer_sk#7, item_sk#8]
 Right keys [2]: [customer_sk#14, item_sk#15]
 Join condition: None
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97/simplified.txt
index 227b3c6..99c8a1d 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97/simplified.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97/simplified.txt
@@ -5,8 +5,8 @@ WholeStageCodegen (8)
         WholeStageCodegen (7)
           HashAggregate [customer_sk,customer_sk] [sum,sum,sum,sum,sum,sum]
             Project [customer_sk,customer_sk]
-              InputAdapter
-                SortMergeJoin [customer_sk,item_sk,customer_sk,item_sk]
+              SortMergeJoin [customer_sk,item_sk,customer_sk,item_sk]
+                InputAdapter
                   WholeStageCodegen (3)
                     Sort [customer_sk,item_sk]
                       HashAggregate [ss_customer_sk,ss_item_sk] [customer_sk,item_sk]
@@ -29,6 +29,7 @@ WholeStageCodegen (8)
                                                         Scan parquet default.date_dim [d_date_sk,d_month_seq]
                                     InputAdapter
                                       ReusedExchange [d_date_sk] #3
+                InputAdapter
                   WholeStageCodegen (6)
                     Sort [customer_sk,item_sk]
                       HashAggregate [cs_bill_customer_sk,cs_item_sk] [customer_sk,item_sk]
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/explain.txt
index 740ea0f..64111ee 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/explain.txt
@@ -10,7 +10,7 @@ TakeOrderedAndProject (70)
                :     +- Exchange (58)
                :        +- * Project (57)
                :           +- * Filter (56)
-               :              +- SortMergeJoin FullOuter (55)
+               :              +- * SortMergeJoin FullOuter (55)
                :                 :- * Sort (27)
                :                 :  +- Exchange (26)
                :                 :     +- * HashAggregate (25)
@@ -317,7 +317,7 @@ Arguments: hashpartitioning(item_sk#38, d_date#33, 5), ENSURE_REQUIREMENTS, [id=
 Input [3]: [item_sk#38, d_date#33, cume_sales#54]
 Arguments: [item_sk#38 ASC NULLS FIRST, d_date#33 ASC NULLS FIRST], false, 0
 
-(55) SortMergeJoin
+(55) SortMergeJoin [codegen id : 29]
 Left keys [2]: [item_sk#11, d_date#6]
 Right keys [2]: [item_sk#38, d_date#33]
 Join condition: None
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/simplified.txt
index ed52e97..1a89b7c 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/simplified.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/simplified.txt
@@ -14,8 +14,8 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store
                           WholeStageCodegen (29)
                             Project [item_sk,item_sk,d_date,d_date,cume_sales,cume_sales]
                               Filter [item_sk,item_sk]
-                                InputAdapter
-                                  SortMergeJoin [item_sk,d_date,item_sk,d_date]
+                                SortMergeJoin [item_sk,d_date,item_sk,d_date]
+                                  InputAdapter
                                     WholeStageCodegen (14)
                                       Sort [item_sk,d_date]
                                         InputAdapter
@@ -73,6 +73,7 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store
                                                                           Sort [ws_item_sk,d_date]
                                                                             InputAdapter
                                                                               ReusedExchange [item_sk,d_date,sumws,ws_item_sk] #4
+                                  InputAdapter
                                     WholeStageCodegen (28)
                                       Sort [item_sk,d_date]
                                         InputAdapter
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a/explain.txt
index cf86cd6..9edb377 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a/explain.txt
@@ -10,7 +10,7 @@ TakeOrderedAndProject (67)
                :     +- Exchange (54)
                :        +- * Project (53)
                :           +- * Filter (52)
-               :              +- SortMergeJoin FullOuter (51)
+               :              +- * SortMergeJoin FullOuter (51)
                :                 :- * Sort (25)
                :                 :  +- Exchange (24)
                :                 :     +- * HashAggregate (23)
@@ -298,7 +298,7 @@ Arguments: hashpartitioning(item_sk#38, d_date#33, 5), ENSURE_REQUIREMENTS, [id=
 Input [3]: [item_sk#38, d_date#33, cume_sales#54]
 Arguments: [item_sk#38 ASC NULLS FIRST, d_date#33 ASC NULLS FIRST], false, 0
 
-(51) SortMergeJoin
+(51) SortMergeJoin [codegen id : 25]
 Left keys [2]: [item_sk#11, d_date#6]
 Right keys [2]: [item_sk#38, d_date#33]
 Join condition: None
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a/simplified.txt
index a99caa4..d6612db 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a/simplified.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a/simplified.txt
@@ -14,8 +14,8 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store
                           WholeStageCodegen (25)
                             Project [item_sk,item_sk,d_date,d_date,cume_sales,cume_sales]
                               Filter [item_sk,item_sk]
-                                InputAdapter
-                                  SortMergeJoin [item_sk,d_date,item_sk,d_date]
+                                SortMergeJoin [item_sk,d_date,item_sk,d_date]
+                                  InputAdapter
                                     WholeStageCodegen (12)
                                       Sort [item_sk,d_date]
                                         InputAdapter
@@ -67,6 +67,7 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store
                                                                           Sort [ws_item_sk,d_date]
                                                                             InputAdapter
                                                                               ReusedExchange [item_sk,d_date,sumws,ws_item_sk] #4
+                                  InputAdapter
                                     WholeStageCodegen (24)
                                       Sort [item_sk,d_date]
                                         InputAdapter
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index f483971..00ea371 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -171,48 +171,54 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
       Seq(Row(0, 0, 0), Row(1, 1, 1), Row(2, 2, 2), Row(3, 3, 3), Row(4, 4, 4)))
   }
 
-  test("Full Outer ShuffledHashJoin should be included in WholeStageCodegen") {
+  test("Full Outer ShuffledHashJoin and SortMergeJoin should be included in WholeStageCodegen") {
     val df1 = spark.range(5).select($"id".as("k1"))
     val df2 = spark.range(10).select($"id".as("k2"))
     val df3 = spark.range(3).select($"id".as("k3"))
 
-    // test one join with unique key from build side
-    val joinUniqueDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", "full_outer")
-    assert(joinUniqueDF.queryExecution.executedPlan.collect {
-      case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
-    }.size === 1)
-    checkAnswer(joinUniqueDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(4, 4),
-      Row(null, 5), Row(null, 6), Row(null, 7), Row(null, 8), Row(null, 9)))
-    assert(joinUniqueDF.count() === 10)
-
-    // test one join with non-unique key from build side
-    val joinNonUniqueDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2" % 3, "full_outer")
-    assert(joinNonUniqueDF.queryExecution.executedPlan.collect {
-      case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
-    }.size === 1)
-    checkAnswer(joinNonUniqueDF, Seq(Row(0, 0), Row(0, 3), Row(0, 6), Row(0, 9), Row(1, 1),
-      Row(1, 4), Row(1, 7), Row(2, 2), Row(2, 5), Row(2, 8), Row(3, null), Row(4, null)))
-
-    // test one join with non-equi condition
-    val joinWithNonEquiDF = df1.join(df2.hint("SHUFFLE_HASH"),
-      $"k1" === $"k2" % 3 && $"k1" + 3 =!= $"k2", "full_outer")
-    assert(joinWithNonEquiDF.queryExecution.executedPlan.collect {
-      case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
-    }.size === 1)
-    checkAnswer(joinWithNonEquiDF, Seq(Row(0, 0), Row(0, 6), Row(0, 9), Row(1, 1),
-      Row(1, 7), Row(2, 2), Row(2, 8), Row(3, null), Row(4, null), Row(null, 3), Row(null, 4),
-      Row(null, 5)))
-
-    // test two joins
-    val twoJoinsDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", "full_outer")
-      .join(df3.hint("SHUFFLE_HASH"), $"k1" === $"k3" && $"k1" + $"k3" =!= 2, "full_outer")
-    assert(twoJoinsDF.queryExecution.executedPlan.collect {
-      case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
-    }.size === 2)
-    checkAnswer(twoJoinsDF,
-      Seq(Row(0, 0, 0), Row(1, 1, null), Row(2, 2, 2), Row(3, 3, null), Row(4, 4, null),
-        Row(null, 5, null), Row(null, 6, null), Row(null, 7, null), Row(null, 8, null),
-        Row(null, 9, null), Row(null, null, 1)))
+    Seq("SHUFFLE_HASH", "SHUFFLE_MERGE").foreach { hint =>
+      // test one join with unique key from build side
+      val joinUniqueDF = df1.join(df2.hint(hint), $"k1" === $"k2", "full_outer")
+      assert(joinUniqueDF.queryExecution.executedPlan.collect {
+        case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
+        case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+      }.size === 1)
+      checkAnswer(joinUniqueDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(4, 4),
+        Row(null, 5), Row(null, 6), Row(null, 7), Row(null, 8), Row(null, 9)))
+      assert(joinUniqueDF.count() === 10)
+
+      // test one join with non-unique key from build side
+      val joinNonUniqueDF = df1.join(df2.hint(hint), $"k1" === $"k2" % 3, "full_outer")
+      assert(joinNonUniqueDF.queryExecution.executedPlan.collect {
+        case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
+        case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+      }.size === 1)
+      checkAnswer(joinNonUniqueDF, Seq(Row(0, 0), Row(0, 3), Row(0, 6), Row(0, 9), Row(1, 1),
+        Row(1, 4), Row(1, 7), Row(2, 2), Row(2, 5), Row(2, 8), Row(3, null), Row(4, null)))
+
+      // test one join with non-equi condition
+      val joinWithNonEquiDF = df1.join(df2.hint(hint),
+        $"k1" === $"k2" % 3 && $"k1" + 3 =!= $"k2", "full_outer")
+      assert(joinWithNonEquiDF.queryExecution.executedPlan.collect {
+        case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
+        case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+      }.size === 1)
+      checkAnswer(joinWithNonEquiDF, Seq(Row(0, 0), Row(0, 6), Row(0, 9), Row(1, 1),
+        Row(1, 7), Row(2, 2), Row(2, 8), Row(3, null), Row(4, null), Row(null, 3), Row(null, 4),
+        Row(null, 5)))
+
+      // test two joins
+      val twoJoinsDF = df1.join(df2.hint(hint), $"k1" === $"k2", "full_outer")
+        .join(df3.hint(hint), $"k1" === $"k3" && $"k1" + $"k3" =!= 2, "full_outer")
+      assert(twoJoinsDF.queryExecution.executedPlan.collect {
+        case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
+        case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+      }.size === 2)
+      checkAnswer(twoJoinsDF,
+        Seq(Row(0, 0, 0), Row(1, 1, null), Row(2, 2, 2), Row(3, 3, null), Row(4, 4, null),
+          Row(null, 5, null), Row(null, 6, null), Row(null, 7, null), Row(null, 8, null),
+          Row(null, 9, null), Row(null, null, 1)))
+    }
   }
 
   test("Left/Right Outer SortMergeJoin should be included in WholeStageCodegen") {

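To make the expected answer of the `joinWithNonEquiDF` case easier to check by hand, here is a minimal, hypothetical Scala sketch of full outer sort-merge join semantics on plain `Seq[Long]` inputs. It is not Spark's implementation or its generated code. `FullOuterSortMergeJoinSketch`, `fullOuterJoin` and its parameters are illustrative names. Join keys are assumed non-null, and `None` stands in for SQL NULL. When keys are equal, the sketch buffers the run of equal-key rows on both sides. It evaluates the extra (non-equi) condition on every buffered pair and emits any buffered row that matched nothing, padded with `None`.

```scala
// Hypothetical sketch of full outer sort-merge join semantics (not Spark's code).
// Assumes non-null join keys; None stands in for SQL NULL.
object FullOuterSortMergeJoinSketch {
  type Joined = (Option[Long], Option[Long])

  def fullOuterJoin(
      left: Seq[Long],
      right: Seq[Long],
      leftKey: Long => Long,
      rightKey: Long => Long,
      cond: (Long, Long) => Boolean): Seq[Joined] = {
    // Sort both inputs on their join keys, as a sort merge join requires.
    val l = left.sortBy(leftKey).toIndexedSeq
    val r = right.sortBy(rightKey).toIndexedSeq
    val out = Vector.newBuilder[Joined]
    var i = 0
    var j = 0
    while (i < l.length || j < r.length) {
      if (j == r.length || (i < l.length && leftKey(l(i)) < rightKey(r(j)))) {
        out += ((Some(l(i)), None)) // left row whose key has no right match
        i += 1
      } else if (i == l.length || leftKey(l(i)) > rightKey(r(j))) {
        out += ((None, Some(r(j)))) // right row whose key has no left match
        j += 1
      } else {
        // Equal keys: buffer the run of rows sharing this key on each side.
        val key = leftKey(l(i))
        val lBuf = l.drop(i).takeWhile(x => leftKey(x) == key)
        val rBuf = r.drop(j).takeWhile(y => rightKey(y) == key)
        val rMatched = Array.fill(rBuf.length)(false)
        for (x <- lBuf) {
          var xMatched = false
          // Evaluate the non-equi condition on every buffered pair.
          for ((y, idx) <- rBuf.zipWithIndex if cond(x, y)) {
            out += ((Some(x), Some(y)))
            xMatched = true
            rMatched(idx) = true
          }
          if (!xMatched) out += ((Some(x), None))
        }
        // Buffered right rows that never satisfied the condition.
        for ((y, idx) <- rBuf.zipWithIndex if !rMatched(idx)) out += ((None, Some(y)))
        i += lBuf.length
        j += rBuf.length
      }
    }
    out.result()
  }

  def main(args: Array[String]): Unit = {
    // Same data and condition as joinWithNonEquiDF:
    // k1 === k2 % 3 && k1 + 3 =!= k2
    val rows = fullOuterJoin(
      left = 0L until 5L,
      right = 0L until 10L,
      leftKey = k => k,
      rightKey = k => k % 3,
      cond = (k1, k2) => k1 + 3 != k2)
    rows.foreach(println)
  }
}
```

Running `main` yields the same 12 rows as the test's expected answer, in merge order rather than sorted order. For example, `(None,Some(3))` appears because the key matches (3 % 3 == 0) but the condition `0 + 3 != 3` fails, so the right row is emitted with a null left side.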