You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/07/19 13:49:11 UTC

spark git commit: [SPARK-21441][SQL] Incorrect Codegen in SortMergeJoinExec results in failures in some cases

Repository: spark
Updated Branches:
  refs/heads/master 4eb081cc8 -> 6b6dd682e


[SPARK-21441][SQL] Incorrect Codegen in SortMergeJoinExec results in failures in some cases

## What changes were proposed in this pull request?

https://issues.apache.org/jira/projects/SPARK/issues/SPARK-21441

This issue can be reproduced by the following example:

```
val spark = SparkSession
   .builder()
   .appName("smj-codegen")
   .master("local")
   .config("spark.sql.autoBroadcastJoinThreshold", "1")
   .getOrCreate()
val df1 = spark.createDataFrame(Seq((1, 1), (2, 2), (3, 3))).toDF("key", "int")
val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"), (3, "3"))).toDF("key", "str")
val df = df1.join(df2, df1("key") === df2("key"))
   .filter("int = 2 or reflect('java.lang.Integer', 'valueOf', str) = 1")
   .select("int")
df.show()
```

In summary, the issue occurs when both of the following hold:
(1) The SortMergeJoin condition contains CodegenFallback expressions.
(2) In the physical plan tree, the SortMergeJoin node is a child of the root node, e.g., the Project in the example above.

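This patch fixes the logic in the `CollapseCodegenStages` rule. `insertInputAdapter` now applies the general `supportCodegen(p)` check before the SortMergeJoin special case; previously that case was guarded only by `j.supportCodegen`. As a result, a SortMergeJoin that cannot be code-generated (e.g. because its condition contains a CodegenFallback expression) is wrapped in an `InputAdapter` instead of being compiled into its parent's generated code.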

## How was this patch tested?
Added a unit test to `WholeStageCodegenSuite`, and verified manually on our cluster.

Author: donnyzone <we...@gmail.com>

Closes #18656 from DonnyZone/Fix_SortMergeJoinExec.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6b6dd682
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6b6dd682
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6b6dd682

Branch: refs/heads/master
Commit: 6b6dd682e84d3b03d0b15fbd81a0d16729e521d2
Parents: 4eb081c
Author: donnyzone <we...@gmail.com>
Authored: Wed Jul 19 21:48:54 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Wed Jul 19 21:48:54 2017 +0800

----------------------------------------------------------------------
 .../sql/execution/WholeStageCodegenExec.scala   |  8 +++----
 .../sql/execution/WholeStageCodegenSuite.scala  | 22 ++++++++++++++++++++
 2 files changed, 26 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6b6dd682/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 1007a7d..34134db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -489,13 +489,13 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
    * Inserts an InputAdapter on top of those that do not support codegen.
    */
   private def insertInputAdapter(plan: SparkPlan): SparkPlan = plan match {
-    case j @ SortMergeJoinExec(_, _, _, _, left, right) if j.supportCodegen =>
-      // The children of SortMergeJoin should do codegen separately.
-      j.copy(left = InputAdapter(insertWholeStageCodegen(left)),
-        right = InputAdapter(insertWholeStageCodegen(right)))
     case p if !supportCodegen(p) =>
       // collapse them recursively
       InputAdapter(insertWholeStageCodegen(p))
+    case j @ SortMergeJoinExec(_, _, _, _, left, right) =>
+      // The children of SortMergeJoin should do codegen separately.
+      j.copy(left = InputAdapter(insertWholeStageCodegen(left)),
+        right = InputAdapter(insertWholeStageCodegen(right)))
     case p =>
       p.withNewChildren(p.children.map(insertInputAdapter))
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/6b6dd682/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index a4b30a2..183c68f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -22,8 +22,10 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Stack}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
+import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.expressions.scalalang.typed
 import org.apache.spark.sql.functions.{avg, broadcast, col, max}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
 
@@ -127,4 +129,24 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
         "named_struct('a',id+2, 'b',id+2) as col2")
       .filter("col1 = col2").count()
   }
+
+  test("SPARK-21441 SortMergeJoin codegen with CodegenFallback expressions should be disabled") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") {
+      import testImplicits._
+
+      val df1 = Seq((1, 1), (2, 2), (3, 3)).toDF("key", "int")
+      val df2 = Seq((1, "1"), (2, "2"), (3, "3")).toDF("key", "str")
+
+      val df = df1.join(df2, df1("key") === df2("key"))
+        .filter("int = 2 or reflect('java.lang.Integer', 'valueOf', str) = 1")
+        .select("int")
+
+      val plan = df.queryExecution.executedPlan
+      assert(!plan.find(p =>
+        p.isInstanceOf[WholeStageCodegenExec] &&
+          p.asInstanceOf[WholeStageCodegenExec].child.children(0)
+            .isInstanceOf[SortMergeJoinExec]).isDefined)
+      assert(df.collect() === Array(Row(1), Row(2)))
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org