Posted to commits@spark.apache.org by we...@apache.org on 2021/11/15 11:35:39 UTC
[spark] branch master updated: [SPARK-35352][SQL] Add code-gen for full outer sort merge join
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 2ef60f72 [SPARK-35352][SQL] Add code-gen for full outer sort merge join
2ef60f72 is described below
commit 2ef60f726c79349cbcda6f34f3e99b32951388bf
Author: Cheng Su <ch...@fb.com>
AuthorDate: Mon Nov 15 19:34:52 2021 +0800
[SPARK-35352][SQL] Add code-gen for full outer sort merge join
### What changes were proposed in this pull request?
This PR adds code-gen for FULL OUTER sort merge join. The change is in `SortMergeJoinExec.scala:codegenFullOuter()`. It follows the same algorithm as the iterator mode (`SortMergeFullOuterJoinScanner`): maintain a buffer for each of the left and right join sides, and iterate over the matched rows in the buffers.
Example query:
```
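// Assumes a SparkSession named `spark` with `import spark.implicits._` in scope.
val df1 = spark.range(5).select($"id".as("k1"))
val df2 = spark.range(10).select($"id".as("k2"))
// "SHUFFLE_MERGE" requests a sort merge join; the unit test also runs with "SHUFFLE_HASH".
df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2" % 3 && $"k1" + 3 =!= $"k2", "full_outer")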
```
Example generated code: https://gist.github.com/c21/5cab9751f24ae448d77a259d28cb77d7
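The buffering algorithm described above can be illustrated outside of Spark. The following is a minimal plain-Scala sketch, not the generated code: the object `FullOuterSmjSketch`, its method names, and the use of `Long` rows are assumptions made for illustration, and null join keys (which the real implementation emits as unmatched rows) are left out for brevity.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical illustration only; none of these names exist in Spark.
object FullOuterSmjSketch {
  type Row = Long
  type Joined = (Option[Row], Option[Row])

  def fullOuterJoin(
      left: Seq[Row],
      right: Seq[Row],
      leftKey: Row => Long,
      rightKey: Row => Long,
      condition: (Row, Row) => Boolean): Seq[Joined] = {
    // Sort merge join expects both inputs to be sorted by join key.
    val l = left.sortBy(leftKey).toIndexedSeq
    val r = right.sortBy(rightKey).toIndexedSeq
    val out = ArrayBuffer.empty[Joined]
    var i = 0
    var j = 0

    while (i < l.length && j < r.length) {
      val lk = leftKey(l(i))
      val rk = rightKey(r(j))
      if (lk < rk) {
        // The left key is smaller: join the left row with a null row.
        out += ((Some(l(i)), None)); i += 1
      } else if (lk > rk) {
        // The right key is smaller: join the right row with a null row.
        out += ((None, Some(r(j)))); j += 1
      } else {
        // Buffer all rows with the same join key from both sides.
        val leftBuffer = ArrayBuffer.empty[Row]
        val rightBuffer = ArrayBuffer.empty[Row]
        while (i < l.length && leftKey(l(i)) == lk) { leftBuffer += l(i); i += 1 }
        while (j < r.length && rightKey(r(j)) == lk) { rightBuffer += r(j); j += 1 }

        // Track which buffered rows satisfied the join condition at least once.
        val leftMatched = new Array[Boolean](leftBuffer.size)
        val rightMatched = new Array[Boolean](rightBuffer.size)
        for (a <- leftBuffer.indices) {
          for (b <- rightBuffer.indices if condition(leftBuffer(a), rightBuffer(b))) {
            out += ((Some(leftBuffer(a)), Some(rightBuffer(b))))
            leftMatched(a) = true
            rightMatched(b) = true
          }
          if (!leftMatched(a)) out += ((Some(leftBuffer(a)), None))
        }
        for (b <- rightBuffer.indices if !rightMatched(b)) {
          out += ((None, Some(rightBuffer(b))))
        }
      }
    }

    // One side is exhausted: every remaining row has no match.
    while (i < l.length) { out += ((Some(l(i)), None)); i += 1 }
    while (j < r.length) { out += ((None, Some(r(j)))); j += 1 }
    out.toSeq
  }

  // The example query above: k1 in [0, 5), k2 in [0, 10),
  // joined on k1 === k2 % 3 && k1 + 3 =!= k2.
  def example(): Seq[Joined] =
    fullOuterJoin(0L until 5L, 0L until 10L, k1 => k1, k2 => k2 % 3, (k1, k2) => k1 + 3 != k2)
}
```

`FullOuterSmjSketch.example()` produces 12 rows, the same set that the non-equi case in `WholeStageCodegenSuite.scala` expects. The sketch uses plain boolean arrays where the PR uses `BitSet` instances that are reused across key groups.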
To help review: this PR also changes several TPCDS plan files. Only the following files contain real code changes:
* `SortMergeJoinExec.scala`
* `WholeStageCodegenSuite.scala`
All other files are auto-generated golden file plan changes for TPCDS queries.
### Why are the changes needed?
Improve the run-time/CPU performance of FULL OUTER sort merge join.
Micro benchmark (same query in `JoinBenchmark.scala`):
```
def sortMergeJoin(): Unit = {
  val N = 2 << 20
  codegenBenchmark("sort merge join", N) {
    val df1 = spark.range(N).selectExpr(s"id * 2 as k1")
    val df2 = spark.range(N).selectExpr(s"id * 3 as k2")
    val df = df1.join(df2, col("k1") === col("k2"), "full_outer")
    assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined)
    df.noop()
  }
}

def sortMergeJoinWithDuplicates(): Unit = {
  val N = 2 << 20
  codegenBenchmark("sort merge join with duplicates", N) {
    val df1 = spark.range(N)
      .selectExpr(s"(id * 15485863) % ${N*10} as k1")
    val df2 = spark.range(N)
      .selectExpr(s"(id * 15485867) % ${N*10} as k2")
    val df = df1.join(df2, col("k1") === col("k2"), "full_outer")
    assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined)
    df.noop()
  }
}
```
The benchmark shows a 20-30% run-time improvement (1.2X-1.3X relative speedup by best time):
```
Running benchmark: sort merge join
  Running case: sort merge join wholestage off
  Stopped after 2 iterations, 2979 ms
  Running case: sort merge join wholestage on
  Stopped after 5 iterations, 5849 ms

Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.16
Intel(R) Core(TM) i9-9980HK CPU 2.40GHz
sort merge join:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)   Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------------------------------------------
sort merge join wholestage off                             1453           1490          52         1.4         693.0       1.0X
sort merge join wholestage on                              1115           1170          43         1.9         531.6       1.3X

Running benchmark: sort merge join with duplicates
  Running case: sort merge join with duplicates wholestage off
  Stopped after 2 iterations, 3236 ms
  Running case: sort merge join with duplicates wholestage on
  Stopped after 5 iterations, 6768 ms

Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.16
Intel(R) Core(TM) i9-9980HK CPU 2.40GHz
sort merge join with duplicates:                  Best Time(ms)   Avg Time(ms)   Stdev(ms)   Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------------------------------------------
sort merge join with duplicates wholestage off             1609           1618          13         1.3         767.2       1.0X
sort merge join with duplicates wholestage on              1330           1354          24         1.6         634.4       1.2X
```
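The numbers in the benchmark output above can be cross-checked with a little arithmetic. The sketch below assumes that the `Relative` column is the ratio of best times and that `Per Row(ns)` is the best time divided by `N = 2 << 20` rows; the object `JoinBenchmarkMath` and its helper names are hypothetical and not part of Spark's benchmark utilities.

```scala
// Hypothetical helper for re-deriving the benchmark columns; not Spark code.
object JoinBenchmarkMath {
  // Rows processed per iteration, as in the benchmark: N = 2 << 20 = 2,097,152.
  val numRows: Long = 2L << 20

  // Ratio of best times (wholestage off / wholestage on).
  def speedup(offMs: Double, onMs: Double): Double = offMs / onMs

  // Share of wall-clock time removed by enabling whole-stage code-gen.
  def timeSavedPercent(offMs: Double, onMs: Double): Double =
    (offMs - onMs) / offMs * 100.0

  // Nanoseconds spent per input row for a given best time in milliseconds.
  def perRowNs(bestMs: Double): Double = bestMs * 1e6 / numRows

  def main(args: Array[String]): Unit = {
    // (case name, best time off, best time on), copied from the output above.
    val cases = Seq(
      ("sort merge join", 1453.0, 1115.0),
      ("sort merge join with duplicates", 1609.0, 1330.0))
    cases.foreach { case (name, off, on) =>
      println(f"$name: ${speedup(off, on)}%.2fX, " +
        f"${timeSavedPercent(off, on)}%.1f%% less time, " +
        f"${perRowNs(on)}%.1f ns/row with code-gen")
    }
  }
}
```

Under these assumptions the best-time ratios are about 1.30 and 1.21, which agrees with the reported 1.3X and 1.2X. Expressed as a reduction in wall-clock time, the same measurements correspond to roughly 23% and 17%.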
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
* Added unit test in `WholeStageCodegenSuite.scala`.
* Existing unit test in `OuterJoinSuite.scala`.
Closes #34581 from c21/smj-codegen.
Authored-by: Cheng Su <ch...@fb.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../sql/execution/joins/SortMergeJoinExec.scala | 256 ++++++++++++++++++++-
.../approved-plans-v1_4/q51.sf100/explain.txt | 4 +-
.../approved-plans-v1_4/q51.sf100/simplified.txt | 5 +-
.../approved-plans-v1_4/q51/explain.txt | 4 +-
.../approved-plans-v1_4/q51/simplified.txt | 5 +-
.../approved-plans-v1_4/q97.sf100/explain.txt | 4 +-
.../approved-plans-v1_4/q97.sf100/simplified.txt | 5 +-
.../approved-plans-v1_4/q97/explain.txt | 4 +-
.../approved-plans-v1_4/q97/simplified.txt | 5 +-
.../approved-plans-v2_7/q51a.sf100/explain.txt | 4 +-
.../approved-plans-v2_7/q51a.sf100/simplified.txt | 5 +-
.../approved-plans-v2_7/q51a/explain.txt | 4 +-
.../approved-plans-v2_7/q51a/simplified.txt | 5 +-
.../sql/execution/WholeStageCodegenSuite.scala | 82 ++++---
14 files changed, 327 insertions(+), 65 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 66054bf..afed14a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -364,7 +364,8 @@ case class SortMergeJoinExec(
}
private lazy val ((streamedPlan, streamedKeys), (bufferedPlan, bufferedKeys)) = joinType match {
- case _: InnerLike | LeftOuter | LeftSemi | LeftAnti => ((left, leftKeys), (right, rightKeys))
+ case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | FullOuter =>
+ ((left, leftKeys), (right, rightKeys))
case RightOuter => ((right, rightKeys), (left, leftKeys))
case x =>
throw new IllegalArgumentException(
@@ -374,9 +375,10 @@ case class SortMergeJoinExec(
private lazy val streamedOutput = streamedPlan.output
private lazy val bufferedOutput = bufferedPlan.output
+ // TODO(SPARK-37316): Add code-gen for existence sort merge join.
override def supportCodegen: Boolean = joinType match {
- case _: InnerLike | LeftOuter | RightOuter | LeftSemi | LeftAnti => true
- case _ => false
+ case _: ExistenceJoin => false
+ case _ => true
}
override def inputRDDs(): Seq[RDD[InternalRow]] = {
@@ -644,6 +646,12 @@ case class SortMergeJoinExec(
override def needCopyResult: Boolean = true
override def doProduce(ctx: CodegenContext): String = {
+ // Specialize `doProduce` code for full outer join, because full outer join needs to
+ // buffer both sides of join.
+ if (joinType == FullOuter) {
+ return codegenFullOuter(ctx)
+ }
+
// Inline mutable state since not many join operations in a task
val streamedInput = ctx.addMutableState("scala.collection.Iterator", "streamedInput",
v => s"$v = inputs[0];", forceInline = true)
@@ -890,6 +898,248 @@ case class SortMergeJoinExec(
""".stripMargin
}
+ /**
+ * Generates the code for Full Outer join.
+ */
+ private def codegenFullOuter(ctx: CodegenContext): String = {
+ // Inline mutable state since not many join operations in a task.
+ // Create class member for input iterator from both sides.
+ val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput",
+ v => s"$v = inputs[0];", forceInline = true)
+ val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput",
+ v => s"$v = inputs[1];", forceInline = true)
+
+ // Create class member for next input row from both sides.
+ val leftInputRow = ctx.addMutableState("InternalRow", "leftInputRow", forceInline = true)
+ val rightInputRow = ctx.addMutableState("InternalRow", "rightInputRow", forceInline = true)
+
+ // Create variables for join keys from both sides.
+ val leftKeyVars = createJoinKey(ctx, leftInputRow, leftKeys, left.output)
+ val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ")
+ val rightKeyVars = createJoinKey(ctx, rightInputRow, rightKeys, right.output)
+ val rightAnyNull = rightKeyVars.map(_.isNull).mkString(" || ")
+ val matchedKeyVars = copyKeys(ctx, leftKeyVars)
+ val leftMatchedKeyVars = createJoinKey(ctx, leftInputRow, leftKeys, left.output)
+ val rightMatchedKeyVars = createJoinKey(ctx, rightInputRow, rightKeys, right.output)
+
+ // Create class member for next output row from both sides.
+ val leftOutputRow = ctx.addMutableState("InternalRow", "leftOutputRow", forceInline = true)
+ val rightOutputRow = ctx.addMutableState("InternalRow", "rightOutputRow", forceInline = true)
+
+ // Create class member for buffers of rows with same join keys from both sides.
+ val bufferClsName = "java.util.ArrayList<InternalRow>"
+ val leftBuffer = ctx.addMutableState(bufferClsName, "leftBuffer",
+ v => s"$v = new $bufferClsName();", forceInline = true)
+ val rightBuffer = ctx.addMutableState(bufferClsName, "rightBuffer",
+ v => s"$v = new $bufferClsName();", forceInline = true)
+ val matchedClsName = classOf[BitSet].getName
+ val leftMatched = ctx.addMutableState(matchedClsName, "leftMatched",
+ v => s"$v = new $matchedClsName(1);", forceInline = true)
+ val rightMatched = ctx.addMutableState(matchedClsName, "rightMatched",
+ v => s"$v = new $matchedClsName(1);", forceInline = true)
+ val leftIndex = ctx.freshName("leftIndex")
+ val rightIndex = ctx.freshName("rightIndex")
+
+ // Generate code for join condition
+ val leftResultVars = genOneSideJoinVars(
+ ctx, leftOutputRow, left, setDefaultValue = true)
+ val rightResultVars = genOneSideJoinVars(
+ ctx, rightOutputRow, right, setDefaultValue = true)
+ val resultVars = leftResultVars ++ rightResultVars
+ val (_, conditionCheck, _) =
+ getJoinCondition(ctx, leftResultVars, left, right, Some(rightOutputRow))
+
+ // Generate code for result output in separate function, as we need to output result from
+ // multiple places in join code.
+ val consumeFullOuterJoinRow = ctx.freshName("consumeFullOuterJoinRow")
+ ctx.addNewFunction(consumeFullOuterJoinRow,
+ s"""
+ |private void $consumeFullOuterJoinRow() throws java.io.IOException {
+ | ${metricTerm(ctx, "numOutputRows")}.add(1);
+ | ${consume(ctx, resultVars)}
+ |}
+ """.stripMargin)
+
+ // Handle the case when input row has no match.
+ val outputLeftNoMatch =
+ s"""
+ |$leftOutputRow = $leftInputRow;
+ |$rightOutputRow = null;
+ |$leftInputRow = null;
+ |$consumeFullOuterJoinRow();
+ """.stripMargin
+ val outputRightNoMatch =
+ s"""
+ |$rightOutputRow = $rightInputRow;
+ |$leftOutputRow = null;
+ |$rightInputRow = null;
+ |$consumeFullOuterJoinRow();
+ """.stripMargin
+
+ // Generate a function to scan both sides to find rows with matched join keys.
+ // The matched rows from both sides are copied in buffers separately. This function assumes
+ // either non-empty `leftIter` and `rightIter`, or non-null `leftInputRow` and `rightInputRow`.
+ //
+ // The function has the following steps:
+ // - Step 1: Find the next `leftInputRow` and `rightInputRow` with non-null join keys.
+ // Output row with null join keys (`outputLeftNoMatch` and `outputRightNoMatch`).
+ //
+ // - Step 2: Compare and find next same join keys from between `leftInputRow` and
+ // `rightInputRow`.
+ // Output row with smaller join keys (`outputLeftNoMatch` and `outputRightNoMatch`).
+ //
+ // - Step 3: Buffer rows with same join keys from both sides into `leftBuffer` and
+ // `rightBuffer`. Reset bit sets for both buffers accordingly (`leftMatched` and
+ // `rightMatched`).
+ val findNextJoinRowsFuncName = ctx.freshName("findNextJoinRows")
+ ctx.addNewFunction(findNextJoinRowsFuncName,
+ s"""
+ |private void $findNextJoinRowsFuncName(
+ | scala.collection.Iterator leftIter,
+ | scala.collection.Iterator rightIter) throws java.io.IOException {
+ | int comp = 0;
+ | $leftBuffer.clear();
+ | $rightBuffer.clear();
+ |
+ | if ($leftInputRow == null) {
+ | $leftInputRow = (InternalRow) leftIter.next();
+ | }
+ | if ($rightInputRow == null) {
+ | $rightInputRow = (InternalRow) rightIter.next();
+ | }
+ |
+ | ${leftKeyVars.map(_.code).mkString("\n")}
+ | if ($leftAnyNull) {
+ | // The left row join key is null, join it with null row
+ | $outputLeftNoMatch
+ | return;
+ | }
+ |
+ | ${rightKeyVars.map(_.code).mkString("\n")}
+ | if ($rightAnyNull) {
+ | // The right row join key is null, join it with null row
+ | $outputRightNoMatch
+ | return;
+ | }
+ |
+ | ${genComparison(ctx, leftKeyVars, rightKeyVars)}
+ | if (comp < 0) {
+ | // The left row join key is smaller, join it with null row
+ | $outputLeftNoMatch
+ | return;
+ | } else if (comp > 0) {
+ | // The right row join key is smaller, join it with null row
+ | $outputRightNoMatch
+ | return;
+ | }
+ |
+ | ${matchedKeyVars.map(_.code).mkString("\n")}
+ | $leftBuffer.add($leftInputRow.copy());
+ | $rightBuffer.add($rightInputRow.copy());
+ | $leftInputRow = null;
+ | $rightInputRow = null;
+ |
+ | // Buffer rows from both sides with same join key
+ | while (leftIter.hasNext()) {
+ | $leftInputRow = (InternalRow) leftIter.next();
+ | ${leftMatchedKeyVars.map(_.code).mkString("\n")}
+ | ${genComparison(ctx, leftMatchedKeyVars, matchedKeyVars)}
+ | if (comp == 0) {
+ |
+ | $leftBuffer.add($leftInputRow.copy());
+ | $leftInputRow = null;
+ | } else {
+ | break;
+ | }
+ | }
+ | while (rightIter.hasNext()) {
+ | $rightInputRow = (InternalRow) rightIter.next();
+ | ${rightMatchedKeyVars.map(_.code).mkString("\n")}
+ | ${genComparison(ctx, rightMatchedKeyVars, matchedKeyVars)}
+ | if (comp == 0) {
+ | $rightBuffer.add($rightInputRow.copy());
+ | $rightInputRow = null;
+ | } else {
+ | break;
+ | }
+ | }
+ |
+ | // Reset bit sets of buffers accordingly
+ | if ($leftBuffer.size() <= $leftMatched.capacity()) {
+ | $leftMatched.clearUntil($leftBuffer.size());
+ | } else {
+ | $leftMatched = new $matchedClsName($leftBuffer.size());
+ | }
+ | if ($rightBuffer.size() <= $rightMatched.capacity()) {
+ | $rightMatched.clearUntil($rightBuffer.size());
+ | } else {
+ | $rightMatched = new $matchedClsName($rightBuffer.size());
+ | }
+ |}
+ """.stripMargin)
+
+ // Scan the left and right buffers to find all matched rows.
+ val matchRowsInBuffer =
+ s"""
+ |int $leftIndex;
+ |int $rightIndex;
+ |
+ |for ($leftIndex = 0; $leftIndex < $leftBuffer.size(); $leftIndex++) {
+ | $leftOutputRow = (InternalRow) $leftBuffer.get($leftIndex);
+ | for ($rightIndex = 0; $rightIndex < $rightBuffer.size(); $rightIndex++) {
+ | $rightOutputRow = (InternalRow) $rightBuffer.get($rightIndex);
+ | $conditionCheck {
+ | $consumeFullOuterJoinRow();
+ | $leftMatched.set($leftIndex);
+ | $rightMatched.set($rightIndex);
+ | }
+ | }
+ |
+ | if (!$leftMatched.get($leftIndex)) {
+ |
+ | $rightOutputRow = null;
+ | $consumeFullOuterJoinRow();
+ | }
+ |}
+ |
+ |$leftOutputRow = null;
+ |for ($rightIndex = 0; $rightIndex < $rightBuffer.size(); $rightIndex++) {
+ | if (!$rightMatched.get($rightIndex)) {
+ | // The right row has never matched any left row, join it with null row
+ | $rightOutputRow = (InternalRow) $rightBuffer.get($rightIndex);
+ | $consumeFullOuterJoinRow();
+ | }
+ |}
+ """.stripMargin
+
+ s"""
+ |while (($leftInputRow != null || $leftInput.hasNext()) &&
+ | ($rightInputRow != null || $rightInput.hasNext())) {
+ | $findNextJoinRowsFuncName($leftInput, $rightInput);
+ | $matchRowsInBuffer
+ | if (shouldStop()) return;
+ |}
+ |
+ |// The right iterator has no more rows, join left row with null
+ |while ($leftInputRow != null || $leftInput.hasNext()) {
+ | if ($leftInputRow == null) {
+ | $leftInputRow = (InternalRow) $leftInput.next();
+ | }
+ | $outputLeftNoMatch
+ | if (shouldStop()) return;
+ |}
+ |
+ |// The left iterator has no more rows, join right row with null
+ |while ($rightInputRow != null || $rightInput.hasNext()) {
+ | if ($rightInputRow == null) {
+ | $rightInputRow = (InternalRow) $rightInput.next();
+ | }
+ | $outputRightNoMatch
+ | if (shouldStop()) return;
+ |}
+ """.stripMargin
+ }
+
override protected def withNewChildrenInternal(
newLeft: SparkPlan, newRight: SparkPlan): SortMergeJoinExec =
copy(left = newLeft, right = newRight)
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51.sf100/explain.txt
index 51b1ae5..cbb189e 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51.sf100/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51.sf100/explain.txt
@@ -5,7 +5,7 @@ TakeOrderedAndProject (37)
+- * Sort (34)
+- Exchange (33)
+- * Project (32)
- +- SortMergeJoin FullOuter (31)
+ +- * SortMergeJoin FullOuter (31)
:- * Sort (15)
: +- Exchange (14)
: +- * Project (13)
@@ -176,7 +176,7 @@ Arguments: hashpartitioning(item_sk#25, d_date#20, 5), ENSURE_REQUIREMENTS, [id=
Input [3]: [item_sk#25, d_date#20, cume_sales#28]
Arguments: [item_sk#25 ASC NULLS FIRST, d_date#20 ASC NULLS FIRST], false, 0
-(31) SortMergeJoin
+(31) SortMergeJoin [codegen id : 13]
Left keys [2]: [item_sk#11, d_date#6]
Right keys [2]: [item_sk#25, d_date#20]
Join condition: None
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51.sf100/simplified.txt
index 38d3f50..489aab1 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51.sf100/simplified.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51.sf100/simplified.txt
@@ -9,8 +9,8 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store
Exchange [item_sk] #1
WholeStageCodegen (13)
Project [item_sk,item_sk,d_date,d_date,cume_sales,cume_sales]
- InputAdapter
- SortMergeJoin [item_sk,d_date,item_sk,d_date]
+ SortMergeJoin [item_sk,d_date,item_sk,d_date]
+ InputAdapter
WholeStageCodegen (6)
Sort [item_sk,d_date]
InputAdapter
@@ -45,6 +45,7 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store
Scan parquet default.date_dim [d_date_sk,d_date,d_month_seq]
InputAdapter
ReusedExchange [d_date_sk,d_date] #5
+ InputAdapter
WholeStageCodegen (12)
Sort [item_sk,d_date]
InputAdapter
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51/explain.txt
index 51b1ae5..cbb189e 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51/explain.txt
@@ -5,7 +5,7 @@ TakeOrderedAndProject (37)
+- * Sort (34)
+- Exchange (33)
+- * Project (32)
- +- SortMergeJoin FullOuter (31)
+ +- * SortMergeJoin FullOuter (31)
:- * Sort (15)
: +- Exchange (14)
: +- * Project (13)
@@ -176,7 +176,7 @@ Arguments: hashpartitioning(item_sk#25, d_date#20, 5), ENSURE_REQUIREMENTS, [id=
Input [3]: [item_sk#25, d_date#20, cume_sales#28]
Arguments: [item_sk#25 ASC NULLS FIRST, d_date#20 ASC NULLS FIRST], false, 0
-(31) SortMergeJoin
+(31) SortMergeJoin [codegen id : 13]
Left keys [2]: [item_sk#11, d_date#6]
Right keys [2]: [item_sk#25, d_date#20]
Join condition: None
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51/simplified.txt
index 38d3f50..489aab1 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51/simplified.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51/simplified.txt
@@ -9,8 +9,8 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store
Exchange [item_sk] #1
WholeStageCodegen (13)
Project [item_sk,item_sk,d_date,d_date,cume_sales,cume_sales]
- InputAdapter
- SortMergeJoin [item_sk,d_date,item_sk,d_date]
+ SortMergeJoin [item_sk,d_date,item_sk,d_date]
+ InputAdapter
WholeStageCodegen (6)
Sort [item_sk,d_date]
InputAdapter
@@ -45,6 +45,7 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store
Scan parquet default.date_dim [d_date_sk,d_date,d_month_seq]
InputAdapter
ReusedExchange [d_date_sk,d_date] #5
+ InputAdapter
WholeStageCodegen (12)
Sort [item_sk,d_date]
InputAdapter
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97.sf100/explain.txt
index e9e97e9..e47aaf2 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97.sf100/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97.sf100/explain.txt
@@ -3,7 +3,7 @@
+- Exchange (22)
+- * HashAggregate (21)
+- * Project (20)
- +- SortMergeJoin FullOuter (19)
+ +- * SortMergeJoin FullOuter (19)
:- * Sort (9)
: +- * HashAggregate (8)
: +- Exchange (7)
@@ -112,7 +112,7 @@ Results [2]: [cs_bill_customer_sk#9 AS customer_sk#14, cs_item_sk#10 AS item_sk#
Input [2]: [customer_sk#14, item_sk#15]
Arguments: [customer_sk#14 ASC NULLS FIRST, item_sk#15 ASC NULLS FIRST], false, 0
-(19) SortMergeJoin
+(19) SortMergeJoin [codegen id : 7]
Left keys [2]: [customer_sk#7, item_sk#8]
Right keys [2]: [customer_sk#14, item_sk#15]
Join condition: None
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97.sf100/simplified.txt
index 227b3c6..99c8a1d 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97.sf100/simplified.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97.sf100/simplified.txt
@@ -5,8 +5,8 @@ WholeStageCodegen (8)
WholeStageCodegen (7)
HashAggregate [customer_sk,customer_sk] [sum,sum,sum,sum,sum,sum]
Project [customer_sk,customer_sk]
- InputAdapter
- SortMergeJoin [customer_sk,item_sk,customer_sk,item_sk]
+ SortMergeJoin [customer_sk,item_sk,customer_sk,item_sk]
+ InputAdapter
WholeStageCodegen (3)
Sort [customer_sk,item_sk]
HashAggregate [ss_customer_sk,ss_item_sk] [customer_sk,item_sk]
@@ -29,6 +29,7 @@ WholeStageCodegen (8)
Scan parquet default.date_dim [d_date_sk,d_month_seq]
InputAdapter
ReusedExchange [d_date_sk] #3
+ InputAdapter
WholeStageCodegen (6)
Sort [customer_sk,item_sk]
HashAggregate [cs_bill_customer_sk,cs_item_sk] [customer_sk,item_sk]
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97/explain.txt
index e9e97e9..e47aaf2 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97/explain.txt
@@ -3,7 +3,7 @@
+- Exchange (22)
+- * HashAggregate (21)
+- * Project (20)
- +- SortMergeJoin FullOuter (19)
+ +- * SortMergeJoin FullOuter (19)
:- * Sort (9)
: +- * HashAggregate (8)
: +- Exchange (7)
@@ -112,7 +112,7 @@ Results [2]: [cs_bill_customer_sk#9 AS customer_sk#14, cs_item_sk#10 AS item_sk#
Input [2]: [customer_sk#14, item_sk#15]
Arguments: [customer_sk#14 ASC NULLS FIRST, item_sk#15 ASC NULLS FIRST], false, 0
-(19) SortMergeJoin
+(19) SortMergeJoin [codegen id : 7]
Left keys [2]: [customer_sk#7, item_sk#8]
Right keys [2]: [customer_sk#14, item_sk#15]
Join condition: None
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97/simplified.txt
index 227b3c6..99c8a1d 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97/simplified.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97/simplified.txt
@@ -5,8 +5,8 @@ WholeStageCodegen (8)
WholeStageCodegen (7)
HashAggregate [customer_sk,customer_sk] [sum,sum,sum,sum,sum,sum]
Project [customer_sk,customer_sk]
- InputAdapter
- SortMergeJoin [customer_sk,item_sk,customer_sk,item_sk]
+ SortMergeJoin [customer_sk,item_sk,customer_sk,item_sk]
+ InputAdapter
WholeStageCodegen (3)
Sort [customer_sk,item_sk]
HashAggregate [ss_customer_sk,ss_item_sk] [customer_sk,item_sk]
@@ -29,6 +29,7 @@ WholeStageCodegen (8)
Scan parquet default.date_dim [d_date_sk,d_month_seq]
InputAdapter
ReusedExchange [d_date_sk] #3
+ InputAdapter
WholeStageCodegen (6)
Sort [customer_sk,item_sk]
HashAggregate [cs_bill_customer_sk,cs_item_sk] [customer_sk,item_sk]
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/explain.txt
index 740ea0f..64111ee 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/explain.txt
@@ -10,7 +10,7 @@ TakeOrderedAndProject (70)
: +- Exchange (58)
: +- * Project (57)
: +- * Filter (56)
- : +- SortMergeJoin FullOuter (55)
+ : +- * SortMergeJoin FullOuter (55)
: :- * Sort (27)
: : +- Exchange (26)
: : +- * HashAggregate (25)
@@ -317,7 +317,7 @@ Arguments: hashpartitioning(item_sk#38, d_date#33, 5), ENSURE_REQUIREMENTS, [id=
Input [3]: [item_sk#38, d_date#33, cume_sales#54]
Arguments: [item_sk#38 ASC NULLS FIRST, d_date#33 ASC NULLS FIRST], false, 0
-(55) SortMergeJoin
+(55) SortMergeJoin [codegen id : 29]
Left keys [2]: [item_sk#11, d_date#6]
Right keys [2]: [item_sk#38, d_date#33]
Join condition: None
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/simplified.txt
index ed52e97..1a89b7c 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/simplified.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/simplified.txt
@@ -14,8 +14,8 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store
WholeStageCodegen (29)
Project [item_sk,item_sk,d_date,d_date,cume_sales,cume_sales]
Filter [item_sk,item_sk]
- InputAdapter
- SortMergeJoin [item_sk,d_date,item_sk,d_date]
+ SortMergeJoin [item_sk,d_date,item_sk,d_date]
+ InputAdapter
WholeStageCodegen (14)
Sort [item_sk,d_date]
InputAdapter
@@ -73,6 +73,7 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store
Sort [ws_item_sk,d_date]
InputAdapter
ReusedExchange [item_sk,d_date,sumws,ws_item_sk] #4
+ InputAdapter
WholeStageCodegen (28)
Sort [item_sk,d_date]
InputAdapter
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a/explain.txt
index cf86cd6..9edb377 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a/explain.txt
@@ -10,7 +10,7 @@ TakeOrderedAndProject (67)
: +- Exchange (54)
: +- * Project (53)
: +- * Filter (52)
- : +- SortMergeJoin FullOuter (51)
+ : +- * SortMergeJoin FullOuter (51)
: :- * Sort (25)
: : +- Exchange (24)
: : +- * HashAggregate (23)
@@ -298,7 +298,7 @@ Arguments: hashpartitioning(item_sk#38, d_date#33, 5), ENSURE_REQUIREMENTS, [id=
Input [3]: [item_sk#38, d_date#33, cume_sales#54]
Arguments: [item_sk#38 ASC NULLS FIRST, d_date#33 ASC NULLS FIRST], false, 0
-(51) SortMergeJoin
+(51) SortMergeJoin [codegen id : 25]
Left keys [2]: [item_sk#11, d_date#6]
Right keys [2]: [item_sk#38, d_date#33]
Join condition: None
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a/simplified.txt
index a99caa4..d6612db 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a/simplified.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a/simplified.txt
@@ -14,8 +14,8 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store
WholeStageCodegen (25)
Project [item_sk,item_sk,d_date,d_date,cume_sales,cume_sales]
Filter [item_sk,item_sk]
- InputAdapter
- SortMergeJoin [item_sk,d_date,item_sk,d_date]
+ SortMergeJoin [item_sk,d_date,item_sk,d_date]
+ InputAdapter
WholeStageCodegen (12)
Sort [item_sk,d_date]
InputAdapter
@@ -67,6 +67,7 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store
Sort [ws_item_sk,d_date]
InputAdapter
ReusedExchange [item_sk,d_date,sumws,ws_item_sk] #4
+ InputAdapter
WholeStageCodegen (24)
Sort [item_sk,d_date]
InputAdapter
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index f483971..00ea371 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -171,48 +171,54 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
Seq(Row(0, 0, 0), Row(1, 1, 1), Row(2, 2, 2), Row(3, 3, 3), Row(4, 4, 4)))
}
- test("Full Outer ShuffledHashJoin should be included in WholeStageCodegen") {
+ test("Full Outer ShuffledHashJoin and SortMergeJoin should be included in WholeStageCodegen") {
val df1 = spark.range(5).select($"id".as("k1"))
val df2 = spark.range(10).select($"id".as("k2"))
val df3 = spark.range(3).select($"id".as("k3"))
- // test one join with unique key from build side
- val joinUniqueDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", "full_outer")
- assert(joinUniqueDF.queryExecution.executedPlan.collect {
- case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
- }.size === 1)
- checkAnswer(joinUniqueDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(4, 4),
- Row(null, 5), Row(null, 6), Row(null, 7), Row(null, 8), Row(null, 9)))
- assert(joinUniqueDF.count() === 10)
-
- // test one join with non-unique key from build side
- val joinNonUniqueDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2" % 3, "full_outer")
- assert(joinNonUniqueDF.queryExecution.executedPlan.collect {
- case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
- }.size === 1)
- checkAnswer(joinNonUniqueDF, Seq(Row(0, 0), Row(0, 3), Row(0, 6), Row(0, 9), Row(1, 1),
- Row(1, 4), Row(1, 7), Row(2, 2), Row(2, 5), Row(2, 8), Row(3, null), Row(4, null)))
-
- // test one join with non-equi condition
- val joinWithNonEquiDF = df1.join(df2.hint("SHUFFLE_HASH"),
- $"k1" === $"k2" % 3 && $"k1" + 3 =!= $"k2", "full_outer")
- assert(joinWithNonEquiDF.queryExecution.executedPlan.collect {
- case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
- }.size === 1)
- checkAnswer(joinWithNonEquiDF, Seq(Row(0, 0), Row(0, 6), Row(0, 9), Row(1, 1),
- Row(1, 7), Row(2, 2), Row(2, 8), Row(3, null), Row(4, null), Row(null, 3), Row(null, 4),
- Row(null, 5)))
-
- // test two joins
- val twoJoinsDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", "full_outer")
- .join(df3.hint("SHUFFLE_HASH"), $"k1" === $"k3" && $"k1" + $"k3" =!= 2, "full_outer")
- assert(twoJoinsDF.queryExecution.executedPlan.collect {
- case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
- }.size === 2)
- checkAnswer(twoJoinsDF,
- Seq(Row(0, 0, 0), Row(1, 1, null), Row(2, 2, 2), Row(3, 3, null), Row(4, 4, null),
- Row(null, 5, null), Row(null, 6, null), Row(null, 7, null), Row(null, 8, null),
- Row(null, 9, null), Row(null, null, 1)))
+ Seq("SHUFFLE_HASH", "SHUFFLE_MERGE").foreach { hint =>
+ // test one join with unique key from build side
+ val joinUniqueDF = df1.join(df2.hint(hint), $"k1" === $"k2", "full_outer")
+ assert(joinUniqueDF.queryExecution.executedPlan.collect {
+ case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
+ case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ }.size === 1)
+ checkAnswer(joinUniqueDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(4, 4),
+ Row(null, 5), Row(null, 6), Row(null, 7), Row(null, 8), Row(null, 9)))
+ assert(joinUniqueDF.count() === 10)
+
+ // test one join with non-unique key from build side
+ val joinNonUniqueDF = df1.join(df2.hint(hint), $"k1" === $"k2" % 3, "full_outer")
+ assert(joinNonUniqueDF.queryExecution.executedPlan.collect {
+ case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
+ case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ }.size === 1)
+ checkAnswer(joinNonUniqueDF, Seq(Row(0, 0), Row(0, 3), Row(0, 6), Row(0, 9), Row(1, 1),
+ Row(1, 4), Row(1, 7), Row(2, 2), Row(2, 5), Row(2, 8), Row(3, null), Row(4, null)))
+
+ // test one join with non-equi condition
+ val joinWithNonEquiDF = df1.join(df2.hint(hint),
+ $"k1" === $"k2" % 3 && $"k1" + 3 =!= $"k2", "full_outer")
+ assert(joinWithNonEquiDF.queryExecution.executedPlan.collect {
+ case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
+ case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ }.size === 1)
+ checkAnswer(joinWithNonEquiDF, Seq(Row(0, 0), Row(0, 6), Row(0, 9), Row(1, 1),
+ Row(1, 7), Row(2, 2), Row(2, 8), Row(3, null), Row(4, null), Row(null, 3), Row(null, 4),
+ Row(null, 5)))
+
+ // test two joins
+ val twoJoinsDF = df1.join(df2.hint(hint), $"k1" === $"k2", "full_outer")
+ .join(df3.hint(hint), $"k1" === $"k3" && $"k1" + $"k3" =!= 2, "full_outer")
+ assert(twoJoinsDF.queryExecution.executedPlan.collect {
+ case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
+ case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ }.size === 2)
+ checkAnswer(twoJoinsDF,
+ Seq(Row(0, 0, 0), Row(1, 1, null), Row(2, 2, 2), Row(3, 3, null), Row(4, 4, null),
+ Row(null, 5, null), Row(null, 6, null), Row(null, 7, null), Row(null, 8, null),
+ Row(null, 9, null), Row(null, null, 1)))
+ }
}
test("Left/Right Outer SortMergeJoin should be included in WholeStageCodegen") {