Posted to commits@spark.apache.org by we...@apache.org on 2017/11/06 14:10:49 UTC
spark git commit: [SPARK-22445][SQL] move CodegenContext.copyResult to CodegenSupport
Repository: spark
Updated Branches:
refs/heads/master 4bacddb60 -> 472db58cb
[SPARK-22445][SQL] move CodegenContext.copyResult to CodegenSupport
## What changes were proposed in this pull request?
`CodegenContext.copyResult` is effectively global state for whole-stage codegen, but the tricky part is that it is only used to pass information from child to parent along the `consume` chain. We have to be very careful in `produce`/`consume` to set it to true when producing multiple result rows, and to reset it to false in operators that start a new pipeline (like sort).
This PR moves `copyResult` to `CodegenSupport` and consults it in `WholeStageCodegenExec`, which is much easier to reason about.
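To make the new contract concrete, here is a condensed sketch of the two flags as tree properties (the trait is heavily simplified, and `ExpandLike`/`SortLike` are illustrative stand-ins, not Spark's actual operators):
```scala
// Condensed sketch: needCopyResult flows bottom-up, needStopCheck flows top-down.
trait CodegenSupport {
  def children: Seq[CodegenSupport]
  var parent: CodegenSupport = _

  // Bottom-up: a leaf returns false; a unary operator inherits its child's decision.
  def needCopyResult: Boolean = children match {
    case Nil        => false
    case Seq(child) => child.needCopyResult
    case _          => throw new UnsupportedOperationException
  }

  // Top-down: ask the parent; the WholeStageCodegen root decides for the pipeline.
  def needStopCheck: Boolean = parent.needStopCheck
}

// An operator that emits several rows per input row forces a copy downstream.
class ExpandLike(child: CodegenSupport) extends CodegenSupport {
  def children = Seq(child)
  override def needCopyResult = true
}

// A pipeline breaker like sort buffers all input itself, so it resets both flags.
class SortLike(child: CodegenSupport) extends CodegenSupport {
  def children = Seq(child)
  override def needCopyResult = false
  override def needStopCheck  = false
}
```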
## How was this patch tested?
Existing tests.
Author: Wenchen Fan <we...@databricks.com>
Closes #19656 from cloud-fan/whole-sage.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/472db58c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/472db58c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/472db58c
Branch: refs/heads/master
Commit: 472db58cb19bbd3025eabbd185d920aab0ebb4da
Parents: 4bacddb
Author: Wenchen Fan <we...@databricks.com>
Authored: Mon Nov 6 15:10:44 2017 +0100
Committer: Wenchen Fan <we...@databricks.com>
Committed: Mon Nov 6 15:10:44 2017 +0100
----------------------------------------------------------------------
.../expressions/codegen/CodeGenerator.scala | 10 ------
.../spark/sql/execution/ColumnarBatchScan.scala | 2 +-
.../apache/spark/sql/execution/ExpandExec.scala | 3 +-
.../spark/sql/execution/GenerateExec.scala | 3 +-
.../apache/spark/sql/execution/SortExec.scala | 14 ++++----
.../sql/execution/WholeStageCodegenExec.scala | 35 +++++++++++++++-----
.../execution/aggregate/HashAggregateExec.scala | 14 ++++----
.../sql/execution/basicPhysicalOperators.scala | 5 +--
.../execution/joins/BroadcastHashJoinExec.scala | 16 +++++++--
.../sql/execution/joins/SortMergeJoinExec.scala | 3 +-
10 files changed, 66 insertions(+), 39 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/472db58c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 58738b5..98eda2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -140,16 +140,6 @@ class CodegenContext {
var currentVars: Seq[ExprCode] = null
/**
- * Whether should we copy the result rows or not.
- *
- * If any operator inside WholeStageCodegen generate multiple rows from a single row (for
- * example, Join), this should be true.
- *
- * If an operator starts a new pipeline, this should be reset to false before calling `consume()`.
- */
- var copyResult: Boolean = false
-
- /**
* Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a
* 3-tuple: java type, variable name, code to init it.
* As an example, ("int", "count", "count = 0;") will produce code:
http://git-wip-us.apache.org/repos/asf/spark/blob/472db58c/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
index eb01e12..1925bad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -115,7 +115,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
val localIdx = ctx.freshName("localIdx")
val localEnd = ctx.freshName("localEnd")
val numRows = ctx.freshName("numRows")
- val shouldStop = if (isShouldStopRequired) {
+ val shouldStop = if (parent.needStopCheck) {
s"if (shouldStop()) { $idx = $rowidx + 1; return; }"
} else {
"// shouldStop check is eliminated"
http://git-wip-us.apache.org/repos/asf/spark/blob/472db58c/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
index d5603b3..33849f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
@@ -93,6 +93,8 @@ case class ExpandExec(
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}
+ override def needCopyResult: Boolean = true
+
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
/*
* When the projections list looks like:
@@ -187,7 +189,6 @@ case class ExpandExec(
val i = ctx.freshName("i")
// these columns have to be declared before the loop.
val evaluate = evaluateVariables(outputColumns)
- ctx.copyResult = true
s"""
|$evaluate
|for (int $i = 0; $i < ${projections.length}; $i ++) {
http://git-wip-us.apache.org/repos/asf/spark/blob/472db58c/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index 65ca374..c142d3b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -132,9 +132,10 @@ case class GenerateExec(
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}
+ override def needCopyResult: Boolean = true
+
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
ctx.currentVars = input
- ctx.copyResult = true
// Add input rows to the values when we are joining
val values = if (join) {
http://git-wip-us.apache.org/repos/asf/spark/blob/472db58c/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index ff71fd4..21765cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -124,6 +124,14 @@ case class SortExec(
// Name of sorter variable used in codegen.
private var sorterVariable: String = _
+ // The result rows come from the sort buffer, so this operator doesn't need to copy its result
+ // even if its child does.
+ override def needCopyResult: Boolean = false
+
+ // The sort operator always consumes all the input rows before outputting any result, so we don't need
+ // a stop check before sorting.
+ override def needStopCheck: Boolean = false
+
override protected def doProduce(ctx: CodegenContext): String = {
val needToSort = ctx.freshName("needToSort")
ctx.addMutableState("boolean", needToSort, s"$needToSort = true;")
@@ -148,10 +156,6 @@ case class SortExec(
| }
""".stripMargin.trim)
- // The child could change `copyResult` to true, but we had already consumed all the rows,
- // so `copyResult` should be reset to `false`.
- ctx.copyResult = false
-
val outputRow = ctx.freshName("outputRow")
val peakMemory = metricTerm(ctx, "peakMemory")
val spillSize = metricTerm(ctx, "spillSize")
@@ -177,8 +181,6 @@ case class SortExec(
""".stripMargin.trim
}
- protected override val shouldStopRequired = false
-
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
s"""
|${row.code}
http://git-wip-us.apache.org/repos/asf/spark/blob/472db58c/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 286cb3b..16b5706 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -213,19 +213,32 @@ trait CodegenSupport extends SparkPlan {
}
/**
- * For optimization to suppress shouldStop() in a loop of WholeStageCodegen.
- * Returning true means we need to insert shouldStop() into the loop producing rows, if any.
+ * Whether or not the result rows of this operator should be copied before being put into a buffer.
+ *
+ * If any operator inside WholeStageCodegen generates multiple rows from a single row (for
+ * example, Join), this should be true.
+ *
+ * If an operator starts a new pipeline, this should be false.
*/
- def isShouldStopRequired: Boolean = {
- return shouldStopRequired && (this.parent == null || this.parent.isShouldStopRequired)
+ def needCopyResult: Boolean = {
+ if (children.isEmpty) {
+ false
+ } else if (children.length == 1) {
+ children.head.asInstanceOf[CodegenSupport].needCopyResult
+ } else {
+ throw new UnsupportedOperationException
+ }
}
/**
- * Set to false if this plan consumes all rows produced by children but doesn't output row
- * to buffer by calling append(), so the children don't require shouldStop()
- * in the loop of producing rows.
+ * Whether or not the children of this operator should generate a stop check when consuming input
+ * rows. This is used to suppress shouldStop() in a loop of WholeStageCodegen.
+ *
+ * This should be false if an operator starts a new pipeline, which means it consumes all rows
+ * produced by its children but doesn't output them to the buffer by calling append(), so the
+ * children don't require shouldStop() in their row-producing loop.
*/
- protected def shouldStopRequired: Boolean = true
+ def needStopCheck: Boolean = parent.needStopCheck
}
@@ -278,6 +291,8 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp
addSuffix: Boolean = false): StringBuilder = {
child.generateTreeString(depth, lastChildren, builder, verbose, "")
}
+
+ override def needCopyResult: Boolean = false
}
object WholeStageCodegenExec {
@@ -467,7 +482,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
}
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
- val doCopy = if (ctx.copyResult) {
+ val doCopy = if (needCopyResult) {
".copy()"
} else {
""
@@ -487,6 +502,8 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
addSuffix: Boolean = false): StringBuilder = {
child.generateTreeString(depth, lastChildren, builder, verbose, "*")
}
+
+ override def needStopCheck: Boolean = true
}
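The defaults above compose across a whole stage. A quick demo wired against the simplified sketch from the description (again illustrative classes, not Spark's):
```scala
// Assumes the CodegenSupport/ExpandLike sketch from the description is in scope.
object FlagDemo extends App {
  class ScanLike extends CodegenSupport { def children = Nil }
  class ProjectLike(child: CodegenSupport) extends CodegenSupport {
    def children = Seq(child)
  }
  class WholeStageLike(child: CodegenSupport) extends CodegenSupport {
    def children = Seq(child)
    override def needStopCheck = true // the root of a stage honors shouldStop()
  }

  // Build the stage scan -> expand -> project -> whole-stage root.
  val scan    = new ScanLike
  val expand  = new ExpandLike(scan)
  val project = new ProjectLike(expand)
  val stage   = new WholeStageLike(project)
  scan.parent = expand; expand.parent = project; project.parent = stage

  // Expand multiplies rows, so the root sees needCopyResult = true...
  assert(stage.needCopyResult)
  // ...and with no pipeline breaker in between, the scan keeps its stop check.
  assert(scan.parent.needStopCheck)
}
```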
http://git-wip-us.apache.org/repos/asf/spark/blob/472db58c/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 43e5ff8..2a208a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -149,6 +149,14 @@ case class HashAggregateExec(
child.asInstanceOf[CodegenSupport].inputRDDs()
}
+ // The result rows come from the aggregate buffer, or a single row (no grouping keys), so this
+ // operator doesn't need to copy its result even if its child does.
+ override def needCopyResult: Boolean = false
+
+ // The aggregate operator always consumes all the input rows before outputting any result, so we
+ // don't need a stop check before aggregating.
+ override def needStopCheck: Boolean = false
+
protected override def doProduce(ctx: CodegenContext): String = {
if (groupingExpressions.isEmpty) {
doProduceWithoutKeys(ctx)
@@ -246,8 +254,6 @@ case class HashAggregateExec(
""".stripMargin
}
- protected override val shouldStopRequired = false
-
private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
// only have DeclarativeAggregate
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
@@ -651,10 +657,6 @@ case class HashAggregateExec(
val outputFunc = generateResultFunction(ctx)
val numOutput = metricTerm(ctx, "numOutputRows")
- // The child could change `copyResult` to true, but we had already consumed all the rows,
- // so `copyResult` should be reset to `false`.
- ctx.copyResult = false
-
def outputFromGeneratedMap: String = {
if (isFastHashMapEnabled) {
if (isVectorizedHashMapEnabled) {
http://git-wip-us.apache.org/repos/asf/spark/blob/472db58c/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index e58c3ce..3c7daa0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -279,6 +279,8 @@ case class SampleExec(
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}
+ override def needCopyResult: Boolean = withReplacement
+
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val numOutput = metricTerm(ctx, "numOutputRows")
val sampler = ctx.freshName("sampler")
@@ -286,7 +288,6 @@ case class SampleExec(
if (withReplacement) {
val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
val initSampler = ctx.freshName("initSampler")
- ctx.copyResult = true
val initSamplerFuncName = ctx.addNewFunction(initSampler,
s"""
@@ -450,7 +451,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
val localIdx = ctx.freshName("localIdx")
val localEnd = ctx.freshName("localEnd")
val range = ctx.freshName("range")
- val shouldStop = if (isShouldStopRequired) {
+ val shouldStop = if (parent.needStopCheck) {
s"if (shouldStop()) { $number = $value + ${step}L; return; }"
} else {
"// shouldStop check is eliminated"
http://git-wip-us.apache.org/repos/asf/spark/blob/472db58c/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index b09da9b..837b852 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -76,6 +76,20 @@ case class BroadcastHashJoinExec(
streamedPlan.asInstanceOf[CodegenSupport].inputRDDs()
}
+ override def needCopyResult: Boolean = joinType match {
+ case _: InnerLike | LeftOuter | RightOuter =>
+ // For inner and outer joins, one row from the streamed side may produce multiple result rows
+ // if the build side has duplicate keys. Then we need to copy the result rows before putting
+ // them in a buffer, because these result rows share one UnsafeRow instance. Note that here
+ // we wait for the broadcast to finish, which is a no-op because it already finished when we
+ // waited for it in `doProduce`.
+ !buildPlan.executeBroadcast[HashedRelation]().value.keyIsUnique
+
+ // Other join types (semi, anti, existence) produce at most one result row for each input
+ // row from the streamed side, so there is no need to copy the result rows.
+ case _ => false
+ }
+
override def doProduce(ctx: CodegenContext): String = {
streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)
}
@@ -237,7 +251,6 @@ case class BroadcastHashJoinExec(
""".stripMargin
} else {
- ctx.copyResult = true
val matches = ctx.freshName("matches")
val iteratorCls = classOf[Iterator[UnsafeRow]].getName
s"""
@@ -310,7 +323,6 @@ case class BroadcastHashJoinExec(
""".stripMargin
} else {
- ctx.copyResult = true
val matches = ctx.freshName("matches")
val iteratorCls = classOf[Iterator[UnsafeRow]].getName
val found = ctx.freshName("found")
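The shared-`UnsafeRow` remark in the comment above is the whole reason for the copy. A toy illustration of the aliasing it guards against, with a hypothetical `MutableRow` standing in for a reused `UnsafeRow` instance (not Spark's API):
```scala
import scala.collection.mutable.ArrayBuffer

// One mutable object reused for every "match", like a joiner's result row.
class MutableRow(var value: Int) { def copyRow() = new MutableRow(value) }

object AliasDemo extends App {
  val shared   = new MutableRow(0)
  val noCopy   = ArrayBuffer[MutableRow]()
  val withCopy = ArrayBuffer[MutableRow]()
  for (v <- 1 to 3) {
    shared.value = v             // the next match overwrites the shared instance
    noCopy   += shared           // buffers three aliases of the same object
    withCopy += shared.copyRow() // buffers three independent rows
  }
  assert(noCopy.map(_.value)   == Seq(3, 3, 3)) // every alias sees the last match
  assert(withCopy.map(_.value) == Seq(1, 2, 3)) // copies preserve each match
}
```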
http://git-wip-us.apache.org/repos/asf/spark/blob/472db58c/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 4e02803..cf7885f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -569,8 +569,9 @@ case class SortMergeJoinExec(
}
}
+ override def needCopyResult: Boolean = true
+
override def doProduce(ctx: CodegenContext): String = {
- ctx.copyResult = true
val leftInput = ctx.freshName("leftInput")
ctx.addMutableState("scala.collection.Iterator", leftInput, s"$leftInput = inputs[0];")
val rightInput = ctx.freshName("rightInput")