Posted to commits@spark.apache.org by rx...@apache.org on 2016/11/20 07:55:15 UTC

spark git commit: [SPARK-15214][SQL] Code-generation for Generate

Repository: spark
Updated Branches:
  refs/heads/master a64f25d8b -> 7ca7a6352


[SPARK-15214][SQL] Code-generation for Generate

## What changes were proposed in this pull request?

This PR adds code generation to `Generate`. It supports two code paths:
- General `TraversableOnce`-based iteration. This is used for regular `Generator` expressions that support code generation. This code path expects the expression to return a `TraversableOnce[InternalRow]` and iterates over the returned collection. This PR adds code generation for the `stack` generator.
- Specialized `ArrayData`/`MapData`-based iteration. This is used for the `explode`, `posexplode` & `inline` functions and operates directly on the `ArrayData`/`MapData` result that the child of the generator returns.
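
The loop shapes behind the two code paths can be sketched in plain Java, the language of the generated code. This is a simplified illustration rather than the code the patch emits: the class and method names are hypothetical, `int[]` and `Iterator<Integer>` stand in for `ArrayData` and `TraversableOnce[InternalRow]`, and the null check on a negative index is an assumption of this sketch. It also shows how the `outer` flag forces a single null row when there is no input.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

// Hypothetical sketch; not part of the patch.
public class GenerateLoopSketch {

    // Collection path: an index-based loop directly over the array.
    static List<Integer> explodeArray(int[] data, boolean outer) {
        List<Integer> out = new ArrayList<>();
        int numElements = data == null ? 0 : data.length;
        // With outer=true and no input, start at -1 so the body runs exactly once.
        int start = (outer && numElements == 0) ? -1 : 0;
        for (int index = start; index < numElements; index++) {
            // Assumption of this sketch: a negative index marks the "no input" row.
            boolean isNull = index < 0;
            out.add(isNull ? null : Integer.valueOf(data[index]));
        }
        return out;
    }

    // TraversableOnce path: an iterator-based loop over the returned rows.
    static List<Integer> explodeIterator(Iterator<Integer> rows, boolean outer) {
        List<Integer> out = new ArrayList<>();
        // Guarantees the first iteration when outer=true, even without input.
        boolean first = outer;
        while (rows.hasNext() || first) {
            boolean hasNext = rows.hasNext();
            out.add(hasNext ? rows.next() : null);
            first = false;
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println(explodeArray(new int[] {1, 2}, false)); // prints [1, 2]
        System.out.println(explodeArray(new int[0], true));        // prints [null]
        System.out.println(explodeIterator(Arrays.asList(3, 4).iterator(), true)); // prints [3, 4]
        System.out.println(explodeIterator(Collections.<Integer>emptyIterator(), true)); // prints [null]
    }
}
```

The index-based loop avoids allocating an iterator and a row object per element, which is one plausible reason the collection path benefits so much from code generation.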

### Benchmarks
I have added some benchmarks. With whole-stage code generation enabled, the generators below now run roughly 7x to 15x faster than with it disabled:
#### Environment
```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6
Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz
```
#### Explode Array
##### Before
```
generate explode array:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate explode array wholestage off         7377 / 7607          2.3         439.7       1.0X
generate explode array wholestage on          6055 / 6086          2.8         360.9       1.2X
```
##### After
```
generate explode array:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate explode array wholestage off         7432 / 7696          2.3         443.0       1.0X
generate explode array wholestage on           631 /  646         26.6          37.6      11.8X
```
#### Explode Map
##### Before
```
generate explode map:                    Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate explode map wholestage off         12792 / 12848          1.3         762.5       1.0X
generate explode map wholestage on          11181 / 11237          1.5         666.5       1.1X
```
##### After
```
generate explode map:                    Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate explode map wholestage off         10949 / 10972          1.5         652.6       1.0X
generate explode map wholestage on             870 /  913         19.3          51.9      12.6X
```
#### Posexplode
##### Before
```
generate posexplode array:               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate posexplode array wholestage off      7547 / 7580          2.2         449.8       1.0X
generate posexplode array wholestage on       5786 / 5838          2.9         344.9       1.3X
```
##### After
```
generate posexplode array:               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate posexplode array wholestage off      7535 / 7548          2.2         449.1       1.0X
generate posexplode array wholestage on        620 /  624         27.1          37.0      12.1X
```
#### Inline
##### Before
```
generate inline array:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate inline array wholestage off          6935 / 6978          2.4         413.3       1.0X
generate inline array wholestage on           6360 / 6400          2.6         379.1       1.1X
```
##### After
```
generate inline array:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate inline array wholestage off          6940 / 6966          2.4         413.6       1.0X
generate inline array wholestage on           1002 / 1012         16.7          59.7       6.9X
```
#### Stack
##### Before
```
generate stack:                          Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate stack wholestage off               12980 / 13104          1.3         773.7       1.0X
generate stack wholestage on                11566 / 11580          1.5         689.4       1.1X
```
##### After
```
generate stack:                          Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate stack wholestage off               12875 / 12949          1.3         767.4       1.0X
generate stack wholestage on                   840 /  845         20.0          50.0      15.3X
```
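
The `Relative` column in these tables appears to be the ratio of the two `Per Row(ns)` values (wholestage off divided by wholestage on); that reading is inferred from the numbers and is not stated in the source. The same arithmetic also gives the before/after improvement with whole-stage code generation enabled, which the tables do not list directly. A small Java sketch, with a hypothetical class and method name:

```java
// Hypothetical helper; not part of the patch or of Spark's benchmark code.
public class BenchmarkRatios {

    /** Ratio of two per-row times, rounded to one decimal place. */
    static double ratio(double slowerNsPerRow, double fasterNsPerRow) {
        return Math.round(slowerNsPerRow / fasterNsPerRow * 10.0) / 10.0;
    }

    public static void main(String[] args) {
        // Explode Array, After: wholestage off vs. on (matches the Relative column).
        System.out.println(ratio(443.0, 37.6)); // prints 11.8
        // Explode Array, wholestage on: Before vs. After.
        System.out.println(ratio(360.9, 37.6)); // prints 9.6
        // Stack, wholestage on: Before vs. After.
        System.out.println(ratio(689.4, 50.0)); // prints 13.8
    }
}
```
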
## How was this patch tested?

Existing tests.

Author: Herman van Hovell <hv...@databricks.com>
Author: Herman van Hovell <hv...@questtec.nl>

Closes #13065 from hvanhovell/SPARK-15214.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7ca7a635
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7ca7a635
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7ca7a635

Branch: refs/heads/master
Commit: 7ca7a635242377634c302b7816ce60bd9c908527
Parents: a64f25d
Author: Herman van Hovell <hv...@databricks.com>
Authored: Sat Nov 19 23:55:09 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Sat Nov 19 23:55:09 2016 -0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/generators.scala   | 110 ++++++++--
 .../SubexpressionEliminationSuite.scala         |  16 +-
 .../spark/sql/execution/GenerateExec.scala      | 202 ++++++++++++++++++-
 .../spark/sql/GeneratorFunctionSuite.scala      |  34 ++++
 .../org/apache/spark/sql/SQLQuerySuite.scala    |   7 -
 .../sql/execution/WholeStageCodegenSuite.scala  |  32 ++-
 .../sql/execution/benchmark/MiscBenchmark.scala |  99 ++++++++-
 7 files changed, 463 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7ca7a635/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index d042bfb..6c38f49 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -17,10 +17,12 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
 import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
 import org.apache.spark.sql.types._
 
@@ -60,6 +62,26 @@ trait Generator extends Expression {
    * rows can be made here.
    */
   def terminate(): TraversableOnce[InternalRow] = Nil
+
+  /**
+   * Check if this generator supports code generation.
+   */
+  def supportCodegen: Boolean = !isInstanceOf[CodegenFallback]
+}
+
+/**
+ * A collection producing [[Generator]]. This trait provides a different path for code generation,
+ * by allowing code generation to return either an [[ArrayData]] or a [[MapData]] object.
+ */
+trait CollectionGenerator extends Generator {
+  /** The position of an element within the collection should also be returned. */
+  def position: Boolean
+
+  /** Rows will be inlined during generation. */
+  def inline: Boolean
+
+  /** The type of the returned collection object. */
+  def collectionType: DataType = dataType
 }
 
 /**
@@ -77,7 +99,9 @@ case class UserDefinedGenerator(
   private def initializeConverters(): Unit = {
     inputRow = new InterpretedProjection(children)
     convertToScala = {
-      val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true)))
+      val inputSchema = StructType(children.map { e =>
+        StructField(e.simpleString, e.dataType, nullable = true)
+      })
       CatalystTypeConverters.createToScalaConverter(inputSchema)
     }.asInstanceOf[InternalRow => Row]
   }
@@ -109,8 +133,7 @@ case class UserDefinedGenerator(
        1  2
        3  NULL
   """)
-case class Stack(children: Seq[Expression])
-    extends Expression with Generator with CodegenFallback {
+case class Stack(children: Seq[Expression]) extends Generator {
 
   private lazy val numRows = children.head.eval().asInstanceOf[Int]
   private lazy val numFields = Math.ceil((children.length - 1.0) / numRows).toInt
@@ -149,21 +172,50 @@ case class Stack(children: Seq[Expression])
       InternalRow(fields: _*)
     }
   }
+
+
+  /**
+   * Only support code generation when stack produces 50 rows or less.
+   */
+  override def supportCodegen: Boolean = numRows <= 50
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    // Rows - we write these into an array.
+    val rowData = ctx.freshName("rows")
+    ctx.addMutableState("InternalRow[]", rowData, s"this.$rowData = new InternalRow[$numRows];")
+    val values = children.tail
+    val dataTypes = values.take(numFields).map(_.dataType)
+    val code = ctx.splitExpressions(ctx.INPUT_ROW, Seq.tabulate(numRows) { row =>
+      val fields = Seq.tabulate(numFields) { col =>
+        val index = row * numFields + col
+        if (index < values.length) values(index) else Literal(null, dataTypes(col))
+      }
+      val eval = CreateStruct(fields).genCode(ctx)
+      s"${eval.code}\nthis.$rowData[$row] = ${eval.value};"
+    })
+
+    // Create the collection.
+    val wrapperClass = classOf[mutable.WrappedArray[_]].getName
+    ctx.addMutableState(
+      s"$wrapperClass<InternalRow>",
+      ev.value,
+      s"this.${ev.value} = $wrapperClass$$.MODULE$$.make(this.$rowData);")
+    ev.copy(code = code, isNull = "false")
+  }
 }
 
 /**
- * A base class for Explode and PosExplode
+ * A base class for [[Explode]] and [[PosExplode]].
  */
-abstract class ExplodeBase(child: Expression, position: Boolean)
-  extends UnaryExpression with Generator with CodegenFallback with Serializable {
+abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with Serializable {
+  override val inline: Boolean = false
 
-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) {
+  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
+    case _: ArrayType | _: MapType =>
       TypeCheckResult.TypeCheckSuccess
-    } else {
+    case _ =>
       TypeCheckResult.TypeCheckFailure(
         s"input to function explode should be array or map type, not ${child.dataType}")
-    }
   }
 
   // hive-compatible default alias for explode function ("col" for array, "key", "value" for map)
@@ -171,7 +223,7 @@ abstract class ExplodeBase(child: Expression, position: Boolean)
     case ArrayType(et, containsNull) =>
       if (position) {
         new StructType()
-          .add("pos", IntegerType, false)
+          .add("pos", IntegerType, nullable = false)
           .add("col", et, containsNull)
       } else {
         new StructType()
@@ -180,12 +232,12 @@ abstract class ExplodeBase(child: Expression, position: Boolean)
     case MapType(kt, vt, valueContainsNull) =>
       if (position) {
         new StructType()
-          .add("pos", IntegerType, false)
-          .add("key", kt, false)
+          .add("pos", IntegerType, nullable = false)
+          .add("key", kt, nullable = false)
           .add("value", vt, valueContainsNull)
       } else {
         new StructType()
-          .add("key", kt, false)
+          .add("key", kt, nullable = false)
           .add("value", vt, valueContainsNull)
       }
   }
@@ -218,6 +270,12 @@ abstract class ExplodeBase(child: Expression, position: Boolean)
         }
     }
   }
+
+  override def collectionType: DataType = child.dataType
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    child.genCode(ctx)
+  }
 }
 
 /**
@@ -239,7 +297,9 @@ abstract class ExplodeBase(child: Expression, position: Boolean)
        20
   """)
 // scalastyle:on line.size.limit
-case class Explode(child: Expression) extends ExplodeBase(child, position = false)
+case class Explode(child: Expression) extends ExplodeBase {
+  override val position: Boolean = false
+}
 
 /**
  * Given an input array produces a sequence of rows for each position and value in the array.
@@ -260,7 +320,9 @@ case class Explode(child: Expression) extends ExplodeBase(child, position = fals
        1  20
   """)
 // scalastyle:on line.size.limit
-case class PosExplode(child: Expression) extends ExplodeBase(child, position = true)
+case class PosExplode(child: Expression) extends ExplodeBase {
+  override val position = true
+}
 
 /**
  * Explodes an array of structs into a table.
@@ -273,10 +335,12 @@ case class PosExplode(child: Expression) extends ExplodeBase(child, position = t
        1  a
        2  b
   """)
-case class Inline(child: Expression) extends UnaryExpression with Generator with CodegenFallback {
+case class Inline(child: Expression) extends UnaryExpression with CollectionGenerator {
+  override val inline: Boolean = true
+  override val position: Boolean = false
 
   override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
-    case ArrayType(et, _) if et.isInstanceOf[StructType] =>
+    case ArrayType(st: StructType, _) =>
       TypeCheckResult.TypeCheckSuccess
     case _ =>
       TypeCheckResult.TypeCheckFailure(
@@ -284,9 +348,11 @@ case class Inline(child: Expression) extends UnaryExpression with Generator with
   }
 
   override def elementSchema: StructType = child.dataType match {
-    case ArrayType(et : StructType, _) => et
+    case ArrayType(st: StructType, _) => st
   }
 
+  override def collectionType: DataType = child.dataType
+
   private lazy val numFields = elementSchema.fields.length
 
   override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
@@ -298,4 +364,8 @@ case class Inline(child: Expression) extends UnaryExpression with Generator with
         yield inputArray.getStruct(i, numFields)
     }
   }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    child.genCode(ctx)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7ca7a635/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
index 1e39b24..2db2a04 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -17,7 +17,8 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.types.{DataType, IntegerType}
 
 class SubexpressionEliminationSuite extends SparkFunSuite {
   test("Semantic equals and hash") {
@@ -162,13 +163,18 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
   test("Children of CodegenFallback") {
     val one = Literal(1)
     val two = Add(one, one)
-    val explode = Explode(two)
-    val add = Add(two, explode)
+    val fallback = CodegenFallbackExpression(two)
+    val add = Add(two, fallback)
 
-    var equivalence = new EquivalentExpressions
+    val equivalence = new EquivalentExpressions
     equivalence.addExprTree(add, true)
-    // the `two` inside `explode` should not be added
+    // the `two` inside `fallback` should not be added
     assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
     assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3)  // add, two, explode
   }
 }
+
+case class CodegenFallbackExpression(child: Expression)
+  extends UnaryExpression with CodegenFallback {
+  override def dataType: DataType = child.dataType
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/7ca7a635/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index 19fbf0c..f80214a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -20,8 +20,10 @@ package org.apache.spark.sql.execution
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
 
 /**
  * For lazy computing, be sure the generator.terminate() called in the very last
@@ -40,6 +42,10 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In
  * output of each into a new stream of rows.  This operation is similar to a `flatMap` in functional
  * programming with one important additional feature, which allows the input rows to be joined with
  * their output.
+ *
+ * This operator supports whole stage code generation for generators that do not implement
+ * terminate().
+ *
  * @param generator the generator expression
  * @param join  when true, each output row is implicitly joined with the input tuple that produced
  *              it.
@@ -54,7 +60,7 @@ case class GenerateExec(
     outer: Boolean,
     output: Seq[Attribute],
     child: SparkPlan)
-  extends UnaryExecNode {
+  extends UnaryExecNode with CodegenSupport {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -103,5 +109,197 @@ case class GenerateExec(
       }
     }
   }
-}
 
+  override def supportCodegen: Boolean = generator.supportCodegen
+
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].inputRDDs()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    child.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+    ctx.currentVars = input
+    ctx.copyResult = true
+
+    // Add input rows to the values when we are joining
+    val values = if (join) {
+      input
+    } else {
+      Seq.empty
+    }
+
+    boundGenerator match {
+      case e: CollectionGenerator => codeGenCollection(ctx, e, values, row)
+      case g => codeGenTraversableOnce(ctx, g, values, row)
+    }
+  }
+
+  /**
+   * Generate code for [[CollectionGenerator]] expressions.
+   */
+  private def codeGenCollection(
+      ctx: CodegenContext,
+      e: CollectionGenerator,
+      input: Seq[ExprCode],
+      row: ExprCode): String = {
+
+    // Generate code for the generator.
+    val data = e.genCode(ctx)
+
+    // Generate looping variables.
+    val index = ctx.freshName("index")
+
+    // Add a check if the generate outer flag is true.
+    val checks = optionalCode(outer, data.isNull)
+
+    // Add position
+    val position = if (e.position) {
+      Seq(ExprCode("", "false", index))
+    } else {
+      Seq.empty
+    }
+
+    // Generate code for either ArrayData or MapData
+    val (initMapData, updateRowData, values) = e.collectionType match {
+      case ArrayType(st: StructType, nullable) if e.inline =>
+        val row = codeGenAccessor(ctx, data.value, "col", index, st, nullable, checks)
+        val fieldChecks = checks ++ optionalCode(nullable, row.isNull)
+        val columns = st.fields.toSeq.zipWithIndex.map { case (f, i) =>
+          codeGenAccessor(ctx, row.value, f.name, i.toString, f.dataType, f.nullable, fieldChecks)
+        }
+        ("", row.code, columns)
+
+      case ArrayType(dataType, nullable) =>
+        ("", "", Seq(codeGenAccessor(ctx, data.value, "col", index, dataType, nullable, checks)))
+
+      case MapType(keyType, valueType, valueContainsNull) =>
+        // Materialize the key and the value arrays before we enter the loop.
+        val keyArray = ctx.freshName("keyArray")
+        val valueArray = ctx.freshName("valueArray")
+        val initArrayData =
+          s"""
+             |ArrayData $keyArray = ${data.isNull} ? null : ${data.value}.keyArray();
+             |ArrayData $valueArray = ${data.isNull} ? null : ${data.value}.valueArray();
+           """.stripMargin
+        val values = Seq(
+          codeGenAccessor(ctx, keyArray, "key", index, keyType, nullable = false, checks),
+          codeGenAccessor(ctx, valueArray, "value", index, valueType, valueContainsNull, checks))
+        (initArrayData, "", values)
+    }
+
+    // In case of outer=true we need to make sure the loop is executed at-least once when the
+    // array/map contains no input. We do this by setting the looping index to -1 if there is no
+    // input, evaluation of the array is prevented by a check in the accessor code.
+    val numElements = ctx.freshName("numElements")
+    val init = if (outer) {
+      s"$numElements == 0 ? -1 : 0"
+    } else {
+      "0"
+    }
+    val numOutput = metricTerm(ctx, "numOutputRows")
+    s"""
+       |${data.code}
+       |$initMapData
+       |int $numElements = ${data.isNull} ? 0 : ${data.value}.numElements();
+       |for (int $index = $init; $index < $numElements; $index++) {
+       |  $numOutput.add(1);
+       |  $updateRowData
+       |  ${consume(ctx, input ++ position ++ values)}
+       |}
+     """.stripMargin
+  }
+
+  /**
+   * Generate code for a regular [[TraversableOnce]] returning [[Generator]].
+   */
+  private def codeGenTraversableOnce(
+      ctx: CodegenContext,
+      e: Expression,
+      input: Seq[ExprCode],
+      row: ExprCode): String = {
+
+    // Generate the code for the generator
+    val data = e.genCode(ctx)
+
+    // Generate looping variables.
+    val iterator = ctx.freshName("iterator")
+    val hasNext = ctx.freshName("hasNext")
+    val current = ctx.freshName("row")
+
+    // Add a check if the generate outer flag is true.
+    val checks = optionalCode(outer, s"!$hasNext")
+    val values = e.dataType match {
+      case ArrayType(st: StructType, nullable) =>
+        st.fields.toSeq.zipWithIndex.map { case (f, i) =>
+          codeGenAccessor(ctx, current, f.name, s"$i", f.dataType, f.nullable, checks)
+        }
+    }
+
+    // In case of outer=true we need to make sure the loop is executed at-least-once when the
+    // iterator contains no input. We do this by adding an 'outer' variable which guarantees
+    // execution of the first iteration even if there is no input. Evaluation of the iterator is
+    // prevented by checks in the next() and accessor code.
+    val numOutput = metricTerm(ctx, "numOutputRows")
+    if (outer) {
+      val outerVal = ctx.freshName("outer")
+      s"""
+         |${data.code}
+         |scala.collection.Iterator<InternalRow> $iterator = ${data.value}.toIterator();
+         |boolean $outerVal = true;
+         |while ($iterator.hasNext() || $outerVal) {
+         |  $numOutput.add(1);
+         |  boolean $hasNext = $iterator.hasNext();
+         |  InternalRow $current = (InternalRow)($hasNext? $iterator.next() : null);
+         |  $outerVal = false;
+         |  ${consume(ctx, input ++ values)}
+         |}
+      """.stripMargin
+    } else {
+      s"""
+         |${data.code}
+         |scala.collection.Iterator<InternalRow> $iterator = ${data.value}.toIterator();
+         |while ($iterator.hasNext()) {
+         |  $numOutput.add(1);
+         |  InternalRow $current = (InternalRow)($iterator.next());
+         |  ${consume(ctx, input ++ values)}
+         |}
+      """.stripMargin
+    }
+  }
+
+  /**
+   * Generate accessor code for ArrayData and InternalRows.
+   */
+  private def codeGenAccessor(
+      ctx: CodegenContext,
+      source: String,
+      name: String,
+      index: String,
+      dt: DataType,
+      nullable: Boolean,
+      initialChecks: Seq[String]): ExprCode = {
+    val value = ctx.freshName(name)
+    val javaType = ctx.javaType(dt)
+    val getter = ctx.getValue(source, dt, index)
+    val checks = initialChecks ++ optionalCode(nullable, s"$source.isNullAt($index)")
+    if (checks.nonEmpty) {
+      val isNull = ctx.freshName("isNull")
+      val code =
+        s"""
+           |boolean $isNull = ${checks.mkString(" || ")};
+           |$javaType $value = $isNull ? ${ctx.defaultValue(dt)} : $getter;
+         """.stripMargin
+      ExprCode(code, isNull, value)
+    } else {
+      ExprCode(s"$javaType $value = $getter;", "false", value)
+    }
+  }
+
+  private def optionalCode(condition: Boolean, code: => String): Seq[String] = {
+    if (condition) Seq(code)
+    else Seq.empty
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/7ca7a635/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
index aedc0a8..f0995ea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
@@ -17,8 +17,12 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, Generator}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{IntegerType, StructType}
 
 class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -202,4 +206,34 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
       df.selectExpr("array(struct(a), named_struct('a', b))").selectExpr("inline(*)"),
       Row(1) :: Row(2) :: Nil)
   }
+
+  test("SPARK-14986: Outer lateral view with empty generate expression") {
+    checkAnswer(
+      sql("select nil from values 1 lateral view outer explode(array()) n as nil"),
+      Row(null) :: Nil
+    )
+  }
+
+  test("outer explode()") {
+    checkAnswer(
+      sql("select * from values 1, 2 lateral view outer explode(array()) a as b"),
+      Row(1, null) :: Row(2, null) :: Nil)
+  }
+
+  test("outer generator()") {
+    spark.sessionState.functionRegistry.registerFunction("empty_gen", _ => EmptyGenerator())
+    checkAnswer(
+      sql("select * from values 1, 2 lateral view outer empty_gen() a as b"),
+      Row(1, null) :: Row(2, null) :: Nil)
+  }
+}
+
+case class EmptyGenerator() extends Generator {
+  override def children: Seq[Expression] = Nil
+  override def elementSchema: StructType = new StructType().add("id", IntegerType)
+  override def eval(input: InternalRow): TraversableOnce[InternalRow] = Seq.empty
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val iteratorClass = classOf[Iterator[_]].getName
+    ev.copy(code = s"$iteratorClass<InternalRow> ${ev.value} = $iteratorClass$$.MODULE$$.empty();")
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7ca7a635/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 6b517bc..a715176 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2086,13 +2086,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     }
   }
 
-  test("SPARK-14986: Outer lateral view with empty generate expression") {
-    checkAnswer(
-      sql("select nil from (select 1 as x ) x lateral view outer explode(array()) n as nil"),
-      Row(null) :: Nil
-    )
-  }
-
   test("data source table created in InMemoryCatalog should be able to read/write") {
     withTable("tbl") {
       sql("CREATE TABLE tbl(i INT, j STRING) USING parquet")

http://git-wip-us.apache.org/repos/asf/spark/blob/7ca7a635/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index f26e5e7..e8ea775 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{Column, Dataset, Row}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Stack}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.apache.spark.sql.expressions.scalalang.typed
@@ -113,4 +115,32 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
         p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
     assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
   }
+
+  test("generate should be included in WholeStageCodegen") {
+    import org.apache.spark.sql.functions._
+    val ds = spark.range(2).select(
+      col("id"),
+      explode(array(col("id") + 1, col("id") + 2)).as("value"))
+    val plan = ds.queryExecution.executedPlan
+    assert(plan.find(p =>
+      p.isInstanceOf[WholeStageCodegenExec] &&
+        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[GenerateExec]).isDefined)
+    assert(ds.collect() === Array(Row(0, 1), Row(0, 2), Row(1, 2), Row(1, 3)))
+  }
+
+  test("large stack generator should not use WholeStageCodegen") {
+    def createStackGenerator(rows: Int): SparkPlan = {
+      val id = UnresolvedAttribute("id")
+      val stack = Stack(Literal(rows) +: Seq.tabulate(rows)(i => Add(id, Literal(i))))
+      spark.range(500).select(Column(stack)).queryExecution.executedPlan
+    }
+    val isCodeGenerated: SparkPlan => Boolean = {
+      case WholeStageCodegenExec(_: GenerateExec) => true
+      case _ => false
+    }
+
+    // Only 'stack' generators that produce 50 rows or fewer are code generated.
+    assert(createStackGenerator(50).find(isCodeGenerated).isDefined)
+    assert(createStackGenerator(100).find(isCodeGenerated).isEmpty)
+  }
 }
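
The rows the two new tests expect can be modelled with plain Scala collections. This is only a sketch of the generator semantics, not the code that `GenerateExec` emits. `GeneratorModel`, `explodeRange` and `stack` are hypothetical names introduced for illustration, and `None` stands in for SQL NULL.

```scala
// Plain-Scala model of generator output; no Spark dependency.
// All names here are hypothetical and exist only for illustration.
object GeneratorModel {
  // Models spark.range(n).select(id, explode(array(id + 1, id + 2))):
  // every input id yields one output row per array element.
  def explodeRange(n: Long): Seq[(Long, Long)] =
    (0L until n).flatMap(id => Seq(id + 1, id + 2).map(value => (id, value)))

  // Models stack(rows, e1, ..., ek): the k values are laid out row by row
  // in `rows` rows of equal width, and missing cells are padded with None.
  def stack[A](rows: Int, values: Seq[A]): Seq[Seq[Option[A]]] = {
    require(rows > 0, "stack needs a positive row count")
    val width = math.ceil(values.length.toDouble / rows).toInt
    Seq.tabulate(rows) { r =>
      Seq.tabulate(width) { c =>
        val i = r * width + c
        if (i < values.length) Some(values(i)) else None
      }
    }
  }
}
```

`GeneratorModel.explodeRange(2)` gives the same four (id, value) pairs that the first test asserts on. In the second test, `Stack(Literal(rows) +: Seq.tabulate(rows)(...))` corresponds to `stack(rows, values)` with exactly `rows` values, so each output row has a single column.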

http://git-wip-us.apache.org/repos/asf/spark/blob/7ca7a635/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala
index 470c781..01773c2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala
@@ -102,7 +102,7 @@ class MiscBenchmark extends BenchmarkBase {
     }
     benchmark.run()
 
-    /**
+    /*
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
     collect:                            Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
@@ -124,7 +124,7 @@ class MiscBenchmark extends BenchmarkBase {
     }
     benchmark.run()
 
-    /**
+    /*
     model name      : Westmere E56xx/L56xx/X56xx (Nehalem-C)
     collect limit:                      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
@@ -132,4 +132,99 @@ class MiscBenchmark extends BenchmarkBase {
     collect limit 2 millions                 3348 / 4005          0.3        3193.3       0.2X
      */
   }
+
+  ignore("generate explode") {
+    val N = 1 << 24
+    runBenchmark("generate explode array", N) {
+      val df = sparkSession.range(N).selectExpr(
+        "id as key",
+        "array(rand(), rand(), rand(), rand(), rand()) as values")
+      df.selectExpr("key", "explode(values) value").count()
+    }
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6
+    Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz
+
+    generate explode array:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    generate explode array wholestage off         6920 / 7129          2.4         412.5       1.0X
+    generate explode array wholestage on           623 /  646         26.9          37.1      11.1X
+     */
+
+    runBenchmark("generate explode map", N) {
+      val df = sparkSession.range(N).selectExpr(
+        "id as key",
+        "map('a', rand(), 'b', rand(), 'c', rand(), 'd', rand(), 'e', rand()) pairs")
+      df.selectExpr("key", "explode(pairs) as (k, v)").count()
+    }
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6
+    Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz
+
+    generate explode map:                    Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    generate explode map wholestage off         11978 / 11993          1.4         714.0       1.0X
+    generate explode map wholestage on             866 /  919         19.4          51.6      13.8X
+     */
+
+    runBenchmark("generate posexplode array", N) {
+      val df = sparkSession.range(N).selectExpr(
+        "id as key",
+        "array(rand(), rand(), rand(), rand(), rand()) as values")
+      df.selectExpr("key", "posexplode(values) as (idx, value)").count()
+    }
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6
+    Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz
+
+    generate posexplode array:               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    generate posexplode array wholestage off      7502 / 7513          2.2         447.1       1.0X
+    generate posexplode array wholestage on        617 /  623         27.2          36.8      12.2X
+     */
+
+    runBenchmark("generate inline array", N) {
+      val df = sparkSession.range(N).selectExpr(
+        "id as key",
+        "array((rand(), rand()), (rand(), rand()), (rand(), 0.0d)) as values")
+      df.selectExpr("key", "inline(values) as (r1, r2)").count()
+    }
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6
+    Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz
+
+    generate inline array:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    generate inline array wholestage off          6901 / 6928          2.4         411.3       1.0X
+    generate inline array wholestage on           1001 / 1010         16.8          59.7       6.9X
+     */
+  }
+
+  ignore("generate regular generator") {
+    val N = 1 << 24
+    runBenchmark("generate stack", N) {
+      val df = sparkSession.range(N).selectExpr(
+        "id as key",
+        "id % 2 as t1",
+        "id % 3 as t2",
+        "id % 5 as t3",
+        "id % 7 as t4",
+        "id % 13 as t5")
+      df.selectExpr("key", "stack(4, t1, t2, t3, t4, t5)").count()
+    }
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6
+    Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz
+
+    generate stack:                          Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    generate stack wholestage off               12953 / 13070          1.3         772.1       1.0X
+    generate stack wholestage on                   836 /  847         20.1          49.8      15.5X
+     */
+  }
 }
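
The columns in the benchmark comments above are consistent with being derived from the Best time and the row count `N = 1 << 24`. The sketch below reproduces that arithmetic for the "generate explode array" case. `BenchmarkColumns` and its method names are hypothetical, and the formulas are an assumption inferred from the reported numbers rather than taken from the benchmark harness.

```scala
// Recomputes the Rate, Per Row and Relative columns from a Best time in ms.
// Hypothetical helper for illustration; not part of the Spark benchmark code.
object BenchmarkColumns {
  // Rows processed per benchmark case, as in the diff: val N = 1 << 24.
  val N: Long = 1L << 24

  // Millions of rows processed per second.
  def rateMPerSec(bestMs: Double): Double = N / (bestMs / 1000.0) / 1e6

  // Nanoseconds spent per row.
  def perRowNs(bestMs: Double): Double = bestMs * 1e6 / N

  // Speedup of a case relative to the baseline case.
  def relative(baselineMs: Double, bestMs: Double): Double = baselineMs / bestMs

  def main(args: Array[String]): Unit = {
    // "generate explode array": 6920 ms (wholestage off), 623 ms (wholestage on).
    val rate = rateMPerSec(623)
    val perRow = perRowNs(623)
    val speedup = relative(6920, 623)
    println(f"$rate%.1f M/s, $perRow%.1f ns/row, $speedup%.1fX")
  }
}
```

Rounded to one decimal place, these values agree with the 26.9 M/s, 37.1 ns and 11.1X reported for the wholestage-on row.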

