Posted to commits@spark.apache.org by li...@apache.org on 2018/01/26 00:11:37 UTC

spark git commit: [SPARK-23032][SQL] Add a per-query codegenStageId to WholeStageCodegenExec

Repository: spark
Updated Branches:
  refs/heads/master 8480c0c57 -> e57f39481


[SPARK-23032][SQL] Add a per-query codegenStageId to WholeStageCodegenExec

## What changes were proposed in this pull request?

**Proposal**

Add a per-query ID to the codegen stages as represented by `WholeStageCodegenExec` operators. This ID will be used in:
- the explain output of the physical plan, and
- the generated class name.

Specifically, this ID will be stable within a query, counting up from 1 in depth-first post-order over all the `WholeStageCodegenExec` operators inserted into a plan.
The ID value 0 is reserved for "free-floating" `WholeStageCodegenExec` objects, which may be created for one-off purposes, e.g. for fallback handling of codegen stages that failed to codegen the whole stage and instead codegen a subset of the child operators (as seen in `org.apache.spark.sql.execution.FileSourceScanExec#doExecute`).
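The numbering rule can be illustrated with a small, self-contained Scala sketch. It is a simplified model, not Spark code: `StageNode` and `CodegenStageIdSketch` are hypothetical names, and the model assumes the codegen stage boundaries are already known, with each node standing for one codegen stage and its `inputs` being the stages that feed it (e.g. across an `Exchange`). Each stage is numbered only after all of its inputs, which is the depth-first post-order described above.

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical model, not Spark's SparkPlan: each node is one codegen stage,
// and `inputs` are the stages feeding it (e.g. across an Exchange).
final case class StageNode(label: String, inputs: List[StageNode])

object CodegenStageIdSketch {
  // Numbers stages 1, 2, 3, ... in depth-first post-order: every input stage
  // receives its ID before the stage that consumes it.
  def assignIds(root: StageNode): List[(Int, String)] = {
    var nextId = 1 // restarts at 1 for every query (cf. resetPerQuery())
    val assigned = ListBuffer.empty[(Int, String)]
    def visit(node: StageNode): Unit = {
      node.inputs.foreach(visit)
      assigned += ((nextId, node.label))
      nextId += 1
    }
    visit(root)
    assigned.toList
  }

  // Stage structure of the SortMergeJoin example query in this description.
  val examplePlan: StageNode =
    StageNode("SortMergeJoin", List(
      StageNode("Sort", List(
        StageNode("Project+Sort", List(StageNode("Project+Range", Nil))))),
      StageNode("Sort", List(StageNode("Range", Nil)))))
}
```

Applied to `examplePlan`, this model assigns the same IDs 1 to 6 that appear in the `*(id)` annotations of the explain output after this change.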

Example: for the following query:
```scala
scala> spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 1)

scala> val df1 = spark.range(10).select('id as 'x, 'id + 1 as 'y).orderBy('x).select('x + 1 as 'z, 'y)
df1: org.apache.spark.sql.DataFrame = [z: bigint, y: bigint]

scala> val df2 = spark.range(5)
df2: org.apache.spark.sql.Dataset[Long] = [id: bigint]

scala> val query = df1.join(df2, 'z === 'id)
query: org.apache.spark.sql.DataFrame = [z: bigint, y: bigint ... 1 more field]
```

The explain output before the change is:
```scala
scala> query.explain
== Physical Plan ==
*SortMergeJoin [z#9L], [id#13L], Inner
:- *Sort [z#9L ASC NULLS FIRST], false, 0
:  +- Exchange hashpartitioning(z#9L, 200)
:     +- *Project [(x#3L + 1) AS z#9L, y#4L]
:        +- *Sort [x#3L ASC NULLS FIRST], true, 0
:           +- Exchange rangepartitioning(x#3L ASC NULLS FIRST, 200)
:              +- *Project [id#0L AS x#3L, (id#0L + 1) AS y#4L]
:                 +- *Range (0, 10, step=1, splits=8)
+- *Sort [id#13L ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(id#13L, 200)
      +- *Range (0, 5, step=1, splits=8)
```
Note how codegen'd operators are annotated with a `"*"` prefix. The `SortMergeJoin` operator and its direct child `Sort` operators are adjacent and all annotated with `"*"`, so it's hard to tell that they're actually in separate codegen stages.

After this change, it becomes:
```scala
scala> query.explain
== Physical Plan ==
*(6) SortMergeJoin [z#9L], [id#13L], Inner
:- *(3) Sort [z#9L ASC NULLS FIRST], false, 0
:  +- Exchange hashpartitioning(z#9L, 200)
:     +- *(2) Project [(x#3L + 1) AS z#9L, y#4L]
:        +- *(2) Sort [x#3L ASC NULLS FIRST], true, 0
:           +- Exchange rangepartitioning(x#3L ASC NULLS FIRST, 200)
:              +- *(1) Project [id#0L AS x#3L, (id#0L + 1) AS y#4L]
:                 +- *(1) Range (0, 10, step=1, splits=8)
+- *(5) Sort [id#13L ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(id#13L, 200)
      +- *(4) Range (0, 5, step=1, splits=8)
```
Note that the annotation prefix becomes `"*(id) "`. Now the `SortMergeJoin` operator and its direct child `Sort` operators visibly have different codegen stage IDs.
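Because the ID is now part of each annotated line, operators can be grouped by codegen stage mechanically. The sketch below is a hypothetical helper (`ExplainStageParser` is not part of Spark); it assumes the tree-drawing prefix consists only of spaces, `:`, `+` and `-`, as in the output above, and that the operator name is the single word following `*(id) `.

```scala
// Hypothetical helper, not part of Spark: groups explain-output operators by
// their "*(id) " codegen stage annotation.
object ExplainStageParser {
  // Optional tree-drawing prefix, then "*(<id>) <OperatorName>".
  private val Annotated = """^[\s:+\-]*\*\((\d+)\) (\w+)""".r

  // Lines without an annotation (e.g. Exchange) are skipped.
  def operatorsByStage(explain: String): Map[Int, List[String]] =
    explain.split("\n").toList
      .flatMap { line =>
        Annotated.findFirstMatchIn(line).map(m => (m.group(1).toInt, m.group(2)))
      }
      .groupBy { case (id, _) => id }
      .map { case (id, ops) => id -> ops.map { case (_, op) => op } }
}
```

On the explain output above, stage 2 groups `Project` and `Sort`, while `SortMergeJoin` is alone in stage 6.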

The ID also shows up in the name of the generated class, as a suffix in the format `GeneratedClass$GeneratedIteratorForCodegenStage<id>` (controlled by the internal config `spark.sql.codegen.useIdInClassName`, which defaults to `true`).

For example, note how `GeneratedClass$GeneratedIteratorForCodegenStage3` and `GeneratedClass$GeneratedIteratorForCodegenStage6` in the following stack trace correspond to the IDs shown in the explain output above:
```
"Executor task launch worker for task 42412957" daemon prio=5 tid=0x58 nid=NA runnable
  java.lang.Thread.State: RUNNABLE
	  at org.apache.spark.sql.execution.UnsafeExternalRowSorter.insertRow(UnsafeExternalRowSorter.java:109)
	  at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.sort_addToSorter$(generated.java:32)
	  at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(generated.java:41)
	  at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	  at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$9$$anon$1.hasNext(WholeStageCodegenExec.scala:494)
	  at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage6.findNextInnerJoinRows$(generated.java:42)
	  at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage6.processNext(generated.java:101)
	  at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	  at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$11$$anon$2.hasNext(WholeStageCodegenExec.scala:513)
	  at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253)
	  at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
	  at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:828)
	  at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:828)
	  at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
	  at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	  at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	  at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
	  at org.apache.spark.scheduler.Task.run(Task.scala:109)
	  at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
	  at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
	  at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
	  at java.lang.Thread.run(Thread.java:748)
```
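The naming rule and its diagnostic use can be sketched together. `generatedClassName` below restates the logic of `WholeStageCodegenExec.generatedClassName()` from this commit as a standalone function, with the config value passed in explicitly instead of being read from `SQLConf`; `stageIdOf` is a hypothetical helper that recovers the stage ID from a stack frame, assuming the class-name suffix format shown in the trace above.

```scala
// Standalone restatement of the naming rule; helper names are hypothetical.
object CodegenClassNames {
  // useIdInClassName stands in for spark.sql.codegen.useIdInClassName.
  def generatedClassName(codegenStageId: Int, useIdInClassName: Boolean): String =
    if (useIdInClassName) s"GeneratedIteratorForCodegenStage$codegenStageId"
    else "GeneratedIterator"

  // Matches the nested-class suffix, e.g. "...GeneratedClass$GeneratedIteratorForCodegenStage3".
  private val StageSuffix = """\$GeneratedIteratorForCodegenStage(\d+)$""".r

  // Returns the codegen stage ID of a frame, or None for non-generated frames.
  def stageIdOf(frame: StackTraceElement): Option[Int] =
    StageSuffix.findFirstMatchIn(frame.getClassName).map(_.group(1).toInt)
}
```

Applied to each frame of a captured stack trace, `stageIdOf` recovers which codegen stages were on the stack; with the config set to `false`, every generated class is named `GeneratedIterator` and no ID can be recovered this way.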

**Rationale**

Right now, Spark SQL codegen lacks the means to differentiate between a couple of things:

1. It's hard to tell which physical operators are in the same WholeStageCodegen stage. Note that this "stage" is a separate notion from Spark's RDD execution stages; it only delineates codegen units.
There can be adjacent physical operators that are both codegen'd but are in separate codegen stages. Some of this is due to hacky implementation details, as with `SortMergeJoin` and its `Sort` inputs: they're hard-coded to be split into separate stages although both are codegen'd.
When printing the explain output of the physical plan, you'd only see the codegen'd physical operators annotated with a preceding star (`'*'`), with no way to figure out whether they're in the same stage.

2. Performance/error diagnosis.
The generated code has class/method names that are hard to tell apart across queries, or even across codegen stages within the same query. If we use a Java-level profiler to collect profiles, or if we encounter a Java-level exception with a stack trace in it, it's really hard to tell which part of a query it came from.
By introducing a per-query codegen stage ID, we'd at least be able to know in which codegen stage (and in turn, which group of physical operators) a profile tick or an exception occurred.

This proposal uses a per-query ID because it's stable within a query: multiple runs of the same query see the same resulting IDs. This benefits understandability for users, and it also plays well with the codegen cache in Spark SQL, which uses the generated source code as the key.

The downside of using per-query IDs, as opposed to a per-session or globally incrementing ID, is of course that we can't tell apart different runs of a query by this ID alone. But for now I believe this is a good enough tradeoff.
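The interaction with a cache keyed by generated source can be illustrated with a toy model. `SourceKeyedCache` and `CodegenCacheSketch` are hypothetical and deliberately ignore everything except the class name that embeds the stage ID; they are not Spark's actual codegen cache. Under per-query IDs, running the same single-stage query several times produces identical source each time, whereas a globally incrementing counter would change the class name, and therefore the cache key, on every run.

```scala
import scala.collection.mutable

// Hypothetical model of a compile cache keyed by generated source text.
final class SourceKeyedCache {
  private val compiled = mutable.Map.empty[String, Int]
  private var compileCount = 0

  def compilations: Int = compileCount

  // "Compiles" only on a cache miss; identical source text is a hit.
  def compile(source: String): Int =
    compiled.getOrElseUpdate(source, { compileCount += 1; compileCount })
}

object CodegenCacheSketch {
  // Placeholder source: only the class name varies with the stage ID.
  def sourceFor(stageId: Int): String =
    s"final class GeneratedIteratorForCodegenStage$stageId { /* identical body */ }"

  // Runs one single-stage query per entry of `stageIds` and returns how many
  // compilations the cache had to perform.
  def compilationsFor(stageIds: Seq[Int]): Int = {
    val cache = new SourceKeyedCache
    stageIds.foreach(id => cache.compile(sourceFor(id)))
    cache.compilations
  }
}
```

Per-query IDs correspond to `compilationsFor(Seq(1, 1, 1))`, which compiles once; a global counter corresponds to `compilationsFor(Seq(1, 2, 3))`, which compiles three times. The new `WholeStageCodegenSuite` test checks the same property by running an identical query twice and asserting that no new bytecode compilation occurs.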

## How was this patch tested?

Existing tests. This PR does not involve any runtime behavior changes other than some name changes.
The SQL query test suites that compare explain outputs have been updated to ignore the newly added `codegenStageId`.
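The normalization amounts to a single regular-expression rewrite. The standalone sketch below reuses the pattern that this commit adds to `SQLQueryTestSuite`; the wrapper object `ExplainNormalizer` is hypothetical.

```scala
// Hypothetical wrapper around the pattern added to SQLQueryTestSuite:
// rewrites "*(<id>) " back to the old "*" prefix so golden files stay stable.
object ExplainNormalizer {
  def stripCodegenStageIds(explain: String): String =
    explain.replaceAll("\\*\\(\\d+\\) ", "*")
}
```

Lines without a codegen annotation, such as `Exchange`, pass through unchanged, and multi-digit IDs like `*(12) ` are handled by the `\d+` quantifier.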

Author: Kris Mok <kr...@databricks.com>

Closes #20224 from rednaxelafx/wsc-codegenstageid.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e57f3948
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e57f3948
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e57f3948

Branch: refs/heads/master
Commit: e57f394818b0a62f99609e1032fede7e981f306f
Parents: 8480c0c
Author: Kris Mok <kr...@databricks.com>
Authored: Thu Jan 25 16:11:33 2018 -0800
Committer: gatorsmile <ga...@gmail.com>
Committed: Thu Jan 25 16:11:33 2018 -0800

----------------------------------------------------------------------
 .../org/apache/spark/sql/internal/SQLConf.scala | 10 +++
 .../sql/execution/DataSourceScanExec.scala      |  2 +-
 .../sql/execution/WholeStageCodegenExec.scala   | 85 +++++++++++++++++---
 .../columnar/InMemoryTableScanExec.scala        |  2 +-
 .../datasources/v2/DataSourceV2ScanExec.scala   |  2 +-
 .../apache/spark/sql/SQLQueryTestSuite.scala    |  3 +-
 .../sql/execution/WholeStageCodegenSuite.scala  | 34 ++++++++
 .../columnar/InMemoryColumnarQuerySuite.scala   |  2 +-
 .../sql/hive/execution/HiveExplainSuite.scala   | 39 +++++++--
 9 files changed, 158 insertions(+), 21 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e57f3948/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 470f88c..b0d18b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -629,6 +629,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME =
+    buildConf("spark.sql.codegen.useIdInClassName")
+    .internal()
+    .doc("When true, embed the (whole-stage) codegen stage ID into " +
+      "the class name of the generated class as a suffix")
+    .booleanConf
+    .createWithDefault(true)
+
   val WHOLESTAGE_MAX_NUM_FIELDS = buildConf("spark.sql.codegen.maxFields")
     .internal()
     .doc("The maximum number of fields (including nested fields) that will be supported before" +
@@ -1264,6 +1272,8 @@ class SQLConf extends Serializable with Logging {
 
   def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED)
 
+  def wholeStageUseIdInClassName: Boolean = getConf(WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME)
+
   def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS)
 
   def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK)

http://git-wip-us.apache.org/repos/asf/spark/blob/e57f3948/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 7c7d79c..aa66ee7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -324,7 +324,7 @@ case class FileSourceScanExec(
       // in the case of fallback, this batched scan should never fail because of:
       // 1) only primitive types are supported
       // 2) the number of columns should be smaller than spark.sql.codegen.maxFields
-      WholeStageCodegenExec(this).execute()
+      WholeStageCodegenExec(this)(codegenStageId = 0).execute()
     } else {
       val unsafeRows = {
         val scan = inputRDD

http://git-wip-us.apache.org/repos/asf/spark/blob/e57f3948/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 8ea9e81..0e525b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution
 
 import java.util.Locale
+import java.util.function.Supplier
 
 import scala.collection.mutable
 
@@ -414,6 +415,58 @@ object WholeStageCodegenExec {
   }
 }
 
+object WholeStageCodegenId {
+  // codegenStageId: ID for codegen stages within a query plan.
+  // It does not affect equality, nor does it participate in destructuring pattern matching
+  // of WholeStageCodegenExec.
+  //
+  // This ID is used to help differentiate between codegen stages. It is included as a part
+  // of the explain output for physical plans, e.g.
+  //
+  // == Physical Plan ==
+  // *(5) SortMergeJoin [x#3L], [y#9L], Inner
+  // :- *(2) Sort [x#3L ASC NULLS FIRST], false, 0
+  // :  +- Exchange hashpartitioning(x#3L, 200)
+  // :     +- *(1) Project [(id#0L % 2) AS x#3L]
+  // :        +- *(1) Filter isnotnull((id#0L % 2))
+  // :           +- *(1) Range (0, 5, step=1, splits=8)
+  // +- *(4) Sort [y#9L ASC NULLS FIRST], false, 0
+  //    +- Exchange hashpartitioning(y#9L, 200)
+  //       +- *(3) Project [(id#6L % 2) AS y#9L]
+  //          +- *(3) Filter isnotnull((id#6L % 2))
+  //             +- *(3) Range (0, 5, step=1, splits=8)
+  //
+  // where the ID makes it obvious that not all adjacent codegen'd plan operators are of the
+  // same codegen stage.
+  //
+  // The codegen stage ID is also optionally included in the name of the generated classes as
+  // a suffix, so that it's easier to associate a generated class back to the physical operator.
+  // This is controlled by SQLConf: spark.sql.codegen.useIdInClassName
+  //
+  // The ID is also included in various log messages.
+  //
+  // Within a query, a codegen stage in a plan starts counting from 1, in "insertion order".
+  // WholeStageCodegenExec operators are inserted into a plan in depth-first post-order.
+  // See CollapseCodegenStages.insertWholeStageCodegen for the definition of insertion order.
+  //
+  // 0 is reserved as a special ID value to indicate a temporary WholeStageCodegenExec object
+  // is created, e.g. for special fallback handling when an existing WholeStageCodegenExec
+  // failed to generate/compile code.
+
+  private val codegenStageCounter = ThreadLocal.withInitial(new Supplier[Integer] {
+    override def get() = 1  // TODO: change to Scala lambda syntax when upgraded to Scala 2.12+
+  })
+
+  def resetPerQuery(): Unit = codegenStageCounter.set(1)
+
+  def getNextStageId(): Int = {
+    val counter = codegenStageCounter
+    val id = counter.get()
+    counter.set(id + 1)
+    id
+  }
+}
+
 /**
  * WholeStageCodegen compiles a subtree of plans that support codegen together into single Java
  * function.
@@ -442,7 +495,8 @@ object WholeStageCodegenExec {
  * `doCodeGen()` will create a `CodeGenContext`, which will hold a list of variables for input,
  * used to generated code for [[BoundReference]].
  */
-case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport {
+case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
+    extends UnaryExecNode with CodegenSupport {
 
   override def output: Seq[Attribute] = child.output
 
@@ -454,6 +508,12 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
     "pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext,
       WholeStageCodegenExec.PIPELINE_DURATION_METRIC))
 
+  def generatedClassName(): String = if (conf.wholeStageUseIdInClassName) {
+    s"GeneratedIteratorForCodegenStage$codegenStageId"
+  } else {
+    "GeneratedIterator"
+  }
+
   /**
    * Generates code for this subtree.
    *
@@ -471,19 +531,23 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
         }
        """, inlineToOuterClass = true)
 
+    val className = generatedClassName()
+
     val source = s"""
       public Object generate(Object[] references) {
-        return new GeneratedIterator(references);
+        return new $className(references);
       }
 
-      ${ctx.registerComment(s"""Codegend pipeline for\n${child.treeString.trim}""")}
-      final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
+      ${ctx.registerComment(
+        s"""Codegend pipeline for stage (id=$codegenStageId)
+           |${this.treeString.trim}""".stripMargin)}
+      final class $className extends ${classOf[BufferedRowIterator].getName} {
 
         private Object[] references;
         private scala.collection.Iterator[] inputs;
         ${ctx.declareMutableStates()}
 
-        public GeneratedIterator(Object[] references) {
+        public $className(Object[] references) {
           this.references = references;
         }
 
@@ -516,7 +580,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
     } catch {
       case _: Exception if !Utils.isTesting && sqlContext.conf.codegenFallback =>
         // We should already saw the error message
-        logWarning(s"Whole-stage codegen disabled for this plan:\n $treeString")
+        logWarning(s"Whole-stage codegen disabled for plan (id=$codegenStageId):\n $treeString")
         return child.execute()
     }
 
@@ -525,7 +589,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
       logInfo(s"Found too long generated codes and JIT optimization might not work: " +
         s"the bytecode size ($maxCodeSize) is above the limit " +
         s"${sqlContext.conf.hugeMethodLimit}, and the whole-stage codegen was disabled " +
-        s"for this plan. To avoid this, you can raise the limit " +
+        s"for this plan (id=$codegenStageId). To avoid this, you can raise the limit " +
         s"`${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}`:\n$treeString")
       child match {
         // The fallback solution of batch file source scan still uses WholeStageCodegenExec
@@ -603,10 +667,12 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
       verbose: Boolean,
       prefix: String = "",
       addSuffix: Boolean = false): StringBuilder = {
-    child.generateTreeString(depth, lastChildren, builder, verbose, "*")
+    child.generateTreeString(depth, lastChildren, builder, verbose, s"*($codegenStageId) ")
   }
 
   override def needStopCheck: Boolean = true
+
+  override protected def otherCopyArgs: Seq[AnyRef] = Seq(codegenStageId.asInstanceOf[Integer])
 }
 
 
@@ -657,13 +723,14 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
     case plan if plan.output.length == 1 && plan.output.head.dataType.isInstanceOf[ObjectType] =>
       plan.withNewChildren(plan.children.map(insertWholeStageCodegen))
     case plan: CodegenSupport if supportCodegen(plan) =>
-      WholeStageCodegenExec(insertInputAdapter(plan))
+      WholeStageCodegenExec(insertInputAdapter(plan))(WholeStageCodegenId.getNextStageId())
     case other =>
       other.withNewChildren(other.children.map(insertWholeStageCodegen))
   }
 
   def apply(plan: SparkPlan): SparkPlan = {
     if (conf.wholeStageEnabled) {
+      WholeStageCodegenId.resetPerQuery()
       insertWholeStageCodegen(plan)
     } else {
       plan

http://git-wip-us.apache.org/repos/asf/spark/blob/e57f3948/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 28b3875..c167f1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -274,7 +274,7 @@ case class InMemoryTableScanExec(
 
   protected override def doExecute(): RDD[InternalRow] = {
     if (supportsBatch) {
-      WholeStageCodegenExec(this).execute()
+      WholeStageCodegenExec(this)(codegenStageId = 0).execute()
     } else {
       inputRDD
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/e57f3948/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
index 69d871d..2c22239 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
@@ -88,7 +88,7 @@ case class DataSourceV2ScanExec(
 
   override protected def doExecute(): RDD[InternalRow] = {
     if (supportsBatch) {
-      WholeStageCodegenExec(this).execute()
+      WholeStageCodegenExec(this)(codegenStageId = 0).execute()
     } else {
       val numOutputRows = longMetric("numOutputRows")
       inputRDD.map { r =>

http://git-wip-us.apache.org/repos/asf/spark/blob/e57f3948/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
index 054ada5..beac969 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
@@ -230,7 +230,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext {
         .replaceAll("Location.*/sql/core/", s"Location ${notIncludedMsg}sql/core/")
         .replaceAll("Created By.*", s"Created By $notIncludedMsg")
         .replaceAll("Created Time.*", s"Created Time $notIncludedMsg")
-        .replaceAll("Last Access.*", s"Last Access $notIncludedMsg"))
+        .replaceAll("Last Access.*", s"Last Access $notIncludedMsg")
+        .replaceAll("\\*\\(\\d+\\) ", "*"))  // remove the WholeStageCodegen codegenStageIds
 
       // If the output is not pre-sorted, sort it.
       if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted)

http://git-wip-us.apache.org/repos/asf/spark/blob/e57f3948/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 242bb48..28ad712 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution
 
+import org.apache.spark.metrics.source.CodegenMetrics
 import org.apache.spark.sql.{QueryTest, Row, SaveMode}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
@@ -273,4 +274,37 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("codegen stage IDs should be preserved in transformations after CollapseCodegenStages") {
+    // test case adapted from DataFrameSuite to trigger ReuseExchange
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2") {
+      val df = spark.range(100)
+      val join = df.join(df, "id")
+      val plan = join.queryExecution.executedPlan
+      assert(!plan.find(p =>
+        p.isInstanceOf[WholeStageCodegenExec] &&
+          p.asInstanceOf[WholeStageCodegenExec].codegenStageId == 0).isDefined,
+        "codegen stage IDs should be preserved through ReuseExchange")
+      checkAnswer(join, df.toDF)
+    }
+  }
+
+  test("including codegen stage ID in generated class name should not regress codegen caching") {
+    import testImplicits._
+
+    withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME.key -> "true") {
+      val bytecodeSizeHisto = CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE
+
+      // the same query run twice should hit the codegen cache
+      spark.range(3).select('id + 2).collect
+      val after1 = bytecodeSizeHisto.getCount
+      spark.range(3).select('id + 2).collect
+      val after2 = bytecodeSizeHisto.getCount // same query shape as above, deliberately
+      // bytecodeSizeHisto's count is always monotonically increasing if new compilation to
+      // bytecode had occurred. If the count stayed the same that means we've got a cache hit.
+      assert(after1 == after2, "Should hit codegen cache. No new compilation to bytecode expected")
+
+      // a different query can result in codegen cache miss, that's by design
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e57f3948/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index ff7c5e5..2280da9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -477,7 +477,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
         assert(planBeforeFilter.head.isInstanceOf[InMemoryTableScanExec])
 
         val execPlan = if (enabled == "true") {
-          WholeStageCodegenExec(planBeforeFilter.head)
+          WholeStageCodegenExec(planBeforeFilter.head)(codegenStageId = 0)
         } else {
           planBeforeFilter.head
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/e57f3948/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
index a4273de..f84d188 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
@@ -154,14 +154,39 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto
     }
   }
 
-  test("EXPLAIN CODEGEN command") {
-    checkKeywordsExist(sql("EXPLAIN CODEGEN SELECT 1"),
-      "WholeStageCodegen",
-      "Generated code:",
-      "/* 001 */ public Object generate(Object[] references) {",
-      "/* 002 */   return new GeneratedIterator(references);",
-      "/* 003 */ }"
+  test("explain output of physical plan should contain proper codegen stage ID") {
+    checkKeywordsExist(sql(
+      """
+        |EXPLAIN SELECT t1.id AS a, t2.id AS b FROM
+        |(SELECT * FROM range(3)) t1 JOIN
+        |(SELECT * FROM range(10)) t2 ON t1.id == t2.id % 3
+      """.stripMargin),
+      "== Physical Plan ==",
+      "*(2) Project ",
+      "+- *(2) BroadcastHashJoin ",
+      "   :- BroadcastExchange ",
+      "   :  +- *(1) Range ",
+      "   +- *(2) Range "
     )
+  }
+
+  test("EXPLAIN CODEGEN command") {
+    // the generated class name in this test should stay in sync with
+    //   org.apache.spark.sql.execution.WholeStageCodegenExec.generatedClassName()
+    for ((useIdInClassName, expectedClassName) <- Seq(
+           ("true", "GeneratedIteratorForCodegenStage1"),
+           ("false", "GeneratedIterator"))) {
+      withSQLConf(
+          SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME.key -> useIdInClassName) {
+        checkKeywordsExist(sql("EXPLAIN CODEGEN SELECT 1"),
+          "WholeStageCodegen",
+          "Generated code:",
+           "/* 001 */ public Object generate(Object[] references) {",
+          s"/* 002 */   return new $expectedClassName(references);",
+           "/* 003 */ }"
+        )
+      }
+    }
 
     checkKeywordsNotExist(sql("EXPLAIN CODEGEN SELECT 1"),
       "== Physical Plan =="

