You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/11/06 14:10:49 UTC

spark git commit: [SPARK-22445][SQL] move CodegenContext.copyResult to CodegenSupport

Repository: spark
Updated Branches:
  refs/heads/master 4bacddb60 -> 472db58cb


[SPARK-22445][SQL] move CodegenContext.copyResult to CodegenSupport

## What changes were proposed in this pull request?

`CodegenContext.copyResult` is effectively a piece of global state for whole-stage codegen. The tricky part is that it is only used to pass information from child to parent along the `consume` call chain. We have to be very careful in `produce`/`consume` to set it to true when producing multiple result rows, and to set it to false in operators that start a new pipeline (like sort).

This PR moves `copyResult` to `CodegenSupport` (as `needCopyResult`) and calls it in `WholeStageCodegenExec`. This is much easier to reason about.

## How was this patch tested?

existing tests

Author: Wenchen Fan <we...@databricks.com>

Closes #19656 from cloud-fan/whole-sage.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/472db58c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/472db58c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/472db58c

Branch: refs/heads/master
Commit: 472db58cb19bbd3025eabbd185d920aab0ebb4da
Parents: 4bacddb
Author: Wenchen Fan <we...@databricks.com>
Authored: Mon Nov 6 15:10:44 2017 +0100
Committer: Wenchen Fan <we...@databricks.com>
Committed: Mon Nov 6 15:10:44 2017 +0100

----------------------------------------------------------------------
 .../expressions/codegen/CodeGenerator.scala     | 10 ------
 .../spark/sql/execution/ColumnarBatchScan.scala |  2 +-
 .../apache/spark/sql/execution/ExpandExec.scala |  3 +-
 .../spark/sql/execution/GenerateExec.scala      |  3 +-
 .../apache/spark/sql/execution/SortExec.scala   | 14 ++++----
 .../sql/execution/WholeStageCodegenExec.scala   | 35 +++++++++++++++-----
 .../execution/aggregate/HashAggregateExec.scala | 14 ++++----
 .../sql/execution/basicPhysicalOperators.scala  |  5 +--
 .../execution/joins/BroadcastHashJoinExec.scala | 16 +++++++--
 .../sql/execution/joins/SortMergeJoinExec.scala |  3 +-
 10 files changed, 66 insertions(+), 39 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/472db58c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 58738b5..98eda2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -140,16 +140,6 @@ class CodegenContext {
   var currentVars: Seq[ExprCode] = null
 
   /**
-   * Whether should we copy the result rows or not.
-   *
-   * If any operator inside WholeStageCodegen generate multiple rows from a single row (for
-   * example, Join), this should be true.
-   *
-   * If an operator starts a new pipeline, this should be reset to false before calling `consume()`.
-   */
-  var copyResult: Boolean = false
-
-  /**
    * Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a
    * 3-tuple: java type, variable name, code to init it.
    * As an example, ("int", "count", "count = 0;") will produce code:

http://git-wip-us.apache.org/repos/asf/spark/blob/472db58c/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
index eb01e12..1925bad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -115,7 +115,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
     val localIdx = ctx.freshName("localIdx")
     val localEnd = ctx.freshName("localEnd")
     val numRows = ctx.freshName("numRows")
-    val shouldStop = if (isShouldStopRequired) {
+    val shouldStop = if (parent.needStopCheck) {
       s"if (shouldStop()) { $idx = $rowidx + 1; return; }"
     } else {
       "// shouldStop check is eliminated"

http://git-wip-us.apache.org/repos/asf/spark/blob/472db58c/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
index d5603b3..33849f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
@@ -93,6 +93,8 @@ case class ExpandExec(
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
+  override def needCopyResult: Boolean = true
+
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
     /*
      * When the projections list looks like:
@@ -187,7 +189,6 @@ case class ExpandExec(
     val i = ctx.freshName("i")
     // these column have to declared before the loop.
     val evaluate = evaluateVariables(outputColumns)
-    ctx.copyResult = true
     s"""
        |$evaluate
        |for (int $i = 0; $i < ${projections.length}; $i ++) {

http://git-wip-us.apache.org/repos/asf/spark/blob/472db58c/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index 65ca374..c142d3b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -132,9 +132,10 @@ case class GenerateExec(
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
+  override def needCopyResult: Boolean = true
+
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
     ctx.currentVars = input
-    ctx.copyResult = true
 
     // Add input rows to the values when we are joining
     val values = if (join) {

http://git-wip-us.apache.org/repos/asf/spark/blob/472db58c/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index ff71fd4..21765cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -124,6 +124,14 @@ case class SortExec(
   // Name of sorter variable used in codegen.
   private var sorterVariable: String = _
 
+  // The result rows come from the sort buffer, so this operator doesn't need to copy its result
+  // even if its child does.
+  override def needCopyResult: Boolean = false
+
+  // Sort operator always consumes all the input rows before outputting any result, so we don't need
+  // a stop check before sorting.
+  override def needStopCheck: Boolean = false
+
   override protected def doProduce(ctx: CodegenContext): String = {
     val needToSort = ctx.freshName("needToSort")
     ctx.addMutableState("boolean", needToSort, s"$needToSort = true;")
@@ -148,10 +156,6 @@ case class SortExec(
         | }
       """.stripMargin.trim)
 
-    // The child could change `copyResult` to true, but we had already consumed all the rows,
-    // so `copyResult` should be reset to `false`.
-    ctx.copyResult = false
-
     val outputRow = ctx.freshName("outputRow")
     val peakMemory = metricTerm(ctx, "peakMemory")
     val spillSize = metricTerm(ctx, "spillSize")
@@ -177,8 +181,6 @@ case class SortExec(
      """.stripMargin.trim
   }
 
-  protected override val shouldStopRequired = false
-
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
     s"""
        |${row.code}

http://git-wip-us.apache.org/repos/asf/spark/blob/472db58c/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 286cb3b..16b5706 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -213,19 +213,32 @@ trait CodegenSupport extends SparkPlan {
   }
 
   /**
-   * For optimization to suppress shouldStop() in a loop of WholeStageCodegen.
-   * Returning true means we need to insert shouldStop() into the loop producing rows, if any.
+   * Whether or not the result rows of this operator should be copied before putting into a buffer.
+   *
+   * If any operator inside WholeStageCodegen generate multiple rows from a single row (for
+   * example, Join), this should be true.
+   *
+   * If an operator starts a new pipeline, this should be false.
    */
-  def isShouldStopRequired: Boolean = {
-    return shouldStopRequired && (this.parent == null || this.parent.isShouldStopRequired)
+  def needCopyResult: Boolean = {
+    if (children.isEmpty) {
+      false
+    } else if (children.length == 1) {
+      children.head.asInstanceOf[CodegenSupport].needCopyResult
+    } else {
+      throw new UnsupportedOperationException
+    }
   }
 
   /**
-   * Set to false if this plan consumes all rows produced by children but doesn't output row
-   * to buffer by calling append(), so the children don't require shouldStop()
-   * in the loop of producing rows.
+   * Whether or not the children of this operator should generate a stop check when consuming input
+   * rows. This is used to suppress shouldStop() in a loop of WholeStageCodegen.
+   *
+   * This should be false if an operator starts a new pipeline, which means it consumes all rows
+   * produced by children but doesn't output row to buffer by calling append(),  so the children
+   * don't require shouldStop() in the loop of producing rows.
    */
-  protected def shouldStopRequired: Boolean = true
+  def needStopCheck: Boolean = parent.needStopCheck
 }
 
 
@@ -278,6 +291,8 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp
       addSuffix: Boolean = false): StringBuilder = {
     child.generateTreeString(depth, lastChildren, builder, verbose, "")
   }
+
+  override def needCopyResult: Boolean = false
 }
 
 object WholeStageCodegenExec {
@@ -467,7 +482,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
   }
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val doCopy = if (ctx.copyResult) {
+    val doCopy = if (needCopyResult) {
       ".copy()"
     } else {
       ""
@@ -487,6 +502,8 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
       addSuffix: Boolean = false): StringBuilder = {
     child.generateTreeString(depth, lastChildren, builder, verbose, "*")
   }
+
+  override def needStopCheck: Boolean = true
 }
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/472db58c/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 43e5ff8..2a208a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -149,6 +149,14 @@ case class HashAggregateExec(
     child.asInstanceOf[CodegenSupport].inputRDDs()
   }
 
+  // The result rows come from the aggregate buffer, or a single row(no grouping keys), so this
+  // operator doesn't need to copy its result even if its child does.
+  override def needCopyResult: Boolean = false
+
+  // Aggregate operator always consumes all the input rows before outputting any result, so we
+  // don't need a stop check before aggregating.
+  override def needStopCheck: Boolean = false
+
   protected override def doProduce(ctx: CodegenContext): String = {
     if (groupingExpressions.isEmpty) {
       doProduceWithoutKeys(ctx)
@@ -246,8 +254,6 @@ case class HashAggregateExec(
      """.stripMargin
   }
 
-  protected override val shouldStopRequired = false
-
   private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     // only have DeclarativeAggregate
     val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
@@ -651,10 +657,6 @@ case class HashAggregateExec(
     val outputFunc = generateResultFunction(ctx)
     val numOutput = metricTerm(ctx, "numOutputRows")
 
-    // The child could change `copyResult` to true, but we had already consumed all the rows,
-    // so `copyResult` should be reset to `false`.
-    ctx.copyResult = false
-
     def outputFromGeneratedMap: String = {
       if (isFastHashMapEnabled) {
         if (isVectorizedHashMapEnabled) {

http://git-wip-us.apache.org/repos/asf/spark/blob/472db58c/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index e58c3ce..3c7daa0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -279,6 +279,8 @@ case class SampleExec(
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
+  override def needCopyResult: Boolean = withReplacement
+
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
     val numOutput = metricTerm(ctx, "numOutputRows")
     val sampler = ctx.freshName("sampler")
@@ -286,7 +288,6 @@ case class SampleExec(
     if (withReplacement) {
       val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
       val initSampler = ctx.freshName("initSampler")
-      ctx.copyResult = true
 
       val initSamplerFuncName = ctx.addNewFunction(initSampler,
         s"""
@@ -450,7 +451,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
     val localIdx = ctx.freshName("localIdx")
     val localEnd = ctx.freshName("localEnd")
     val range = ctx.freshName("range")
-    val shouldStop = if (isShouldStopRequired) {
+    val shouldStop = if (parent.needStopCheck) {
       s"if (shouldStop()) { $number = $value + ${step}L; return; }"
     } else {
       "// shouldStop check is eliminated"

http://git-wip-us.apache.org/repos/asf/spark/blob/472db58c/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index b09da9b..837b852 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -76,6 +76,20 @@ case class BroadcastHashJoinExec(
     streamedPlan.asInstanceOf[CodegenSupport].inputRDDs()
   }
 
+  override def needCopyResult: Boolean = joinType match {
+    case _: InnerLike | LeftOuter | RightOuter =>
+      // For inner and outer joins, one row from the streamed side may produce multiple result rows,
+      // if the build side has duplicated keys. Then we need to copy the result rows before putting
+      // them in a buffer, because these result rows share one UnsafeRow instance. Note that here
+      // we wait for the broadcast to be finished, which is a no-op because it's already finished
+      // when we wait it in `doProduce`.
+      !buildPlan.executeBroadcast[HashedRelation]().value.keyIsUnique
+
+    // Other joins types(semi, anti, existence) can at most produce one result row for one input
+    // row from the streamed side, so no need to copy the result rows.
+    case _ => false
+  }
+
   override def doProduce(ctx: CodegenContext): String = {
     streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
@@ -237,7 +251,6 @@ case class BroadcastHashJoinExec(
        """.stripMargin
 
     } else {
-      ctx.copyResult = true
       val matches = ctx.freshName("matches")
       val iteratorCls = classOf[Iterator[UnsafeRow]].getName
       s"""
@@ -310,7 +323,6 @@ case class BroadcastHashJoinExec(
        """.stripMargin
 
     } else {
-      ctx.copyResult = true
       val matches = ctx.freshName("matches")
       val iteratorCls = classOf[Iterator[UnsafeRow]].getName
       val found = ctx.freshName("found")

http://git-wip-us.apache.org/repos/asf/spark/blob/472db58c/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 4e02803..cf7885f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -569,8 +569,9 @@ case class SortMergeJoinExec(
     }
   }
 
+  override def needCopyResult: Boolean = true
+
   override def doProduce(ctx: CodegenContext): String = {
-    ctx.copyResult = true
     val leftInput = ctx.freshName("leftInput")
     ctx.addMutableState("scala.collection.Iterator", leftInput, s"$leftInput = inputs[0];")
     val rightInput = ctx.freshName("rightInput")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org