You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2016/02/24 00:00:17 UTC

spark git commit: [SPARK-13373] [SQL] generate sort merge join

Repository: spark
Updated Branches:
  refs/heads/master c481bdf51 -> 9cdd867da


[SPARK-13373] [SQL] generate sort merge join

## What changes were proposed in this pull request?

Generates code for SortMergeJoin.

## How was this patch tested?

Unit tests, plus manual testing with TPCDS Q72, which showed a 70% performance improvement (from 42s to 25s). The micro benchmark shows only minor improvements; the gain may depend on the distribution of the data and the number of columns.

Author: Davies Liu <da...@databricks.com>

Closes #11248 from davies/gen_smj.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9cdd867d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9cdd867d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9cdd867d

Branch: refs/heads/master
Commit: 9cdd867da978629ea2f61f94e3c346fa0bfecf0e
Parents: c481bdf
Author: Davies Liu <da...@databricks.com>
Authored: Tue Feb 23 15:00:10 2016 -0800
Committer: Davies Liu <da...@gmail.com>
Committed: Tue Feb 23 15:00:10 2016 -0800

----------------------------------------------------------------------
 .../spark/storage/DiskBlockObjectWriter.scala   |   1 +
 .../sql/execution/BufferedRowIterator.java      |  23 +-
 .../org/apache/spark/sql/execution/Expand.scala |   4 +-
 .../spark/sql/execution/WholeStageCodegen.scala |  71 ++++--
 .../execution/aggregate/TungstenAggregate.scala |   4 +-
 .../spark/sql/execution/basicOperators.scala    |  20 +-
 .../sql/execution/joins/BroadcastHashJoin.scala |   4 +-
 .../sql/execution/joins/CartesianProduct.scala  |   1 -
 .../sql/execution/joins/SortMergeJoin.scala     | 247 ++++++++++++++++++-
 .../execution/BenchmarkWholeStageCodegen.scala  |  34 +++
 .../spark/sql/sources/BucketedReadSuite.scala   |   3 +-
 11 files changed, 360 insertions(+), 52 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9cdd867d/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index c34d49c..9cc4084 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -203,6 +203,7 @@ private[spark] class DiskBlockObjectWriter(
     numRecordsWritten += 1
     writeMetrics.incRecordsWritten(1)
 
+    // TODO: call updateBytesWritten() less frequently.
     if (numRecordsWritten % 32 == 0) {
       updateBytesWritten()
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/9cdd867d/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
index ea20115..1d1d7ed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -29,12 +29,9 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 /**
  * An iterator interface used to pull the output from generated function for multiple operators
  * (whole stage codegen).
- *
- * TODO: replaced it by batched columnar format.
  */
-public class BufferedRowIterator {
+public abstract class BufferedRowIterator {
   protected LinkedList<InternalRow> currentRows = new LinkedList<>();
-  protected Iterator<InternalRow> input;
   // used when there is no column in output
   protected UnsafeRow unsafeRow = new UnsafeRow(0);
 
@@ -49,8 +46,16 @@ public class BufferedRowIterator {
     return currentRows.remove();
   }
 
-  public void setInput(Iterator<InternalRow> iter) {
-    input = iter;
+  /**
+   * Initializes from array of iterators of InternalRow.
+   */
+  public abstract void init(Iterator<InternalRow> iters[]);
+
+  /**
+   * Append a row to currentRows.
+   */
+  protected void append(InternalRow row) {
+    currentRows.add(row);
   }
 
   /**
@@ -74,9 +79,5 @@ public class BufferedRowIterator {
    *
    * After it's called, if currentRow is still null, it means no more rows left.
    */
-  protected void processNext() throws IOException {
-    if (input.hasNext()) {
-      currentRows.add(input.next());
-    }
-  }
+  protected abstract void processNext() throws IOException;
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/9cdd867d/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index a9e77ab..12998a3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -85,8 +85,8 @@ case class Expand(
     }
   }
 
-  override def upstream(): RDD[InternalRow] = {
-    child.asInstanceOf[CodegenSupport].upstream()
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
   }
 
   protected override def doProduce(ctx: CodegenContext): String = {

http://git-wip-us.apache.org/repos/asf/spark/blob/9cdd867d/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index d79b547..08c52e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.toCommentSafeString
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight, SortMergeJoin}
 import org.apache.spark.sql.execution.metric.LongSQLMetricValue
 
 /**
@@ -40,7 +40,8 @@ trait CodegenSupport extends SparkPlan {
   /** Prefix used in the current operator's variable names. */
   private def variablePrefix: String = this match {
     case _: TungstenAggregate => "agg"
-    case _: BroadcastHashJoin => "join"
+    case _: BroadcastHashJoin => "bhj"
+    case _: SortMergeJoin => "smj"
     case _ => nodeName.toLowerCase
   }
 
@@ -68,9 +69,11 @@ trait CodegenSupport extends SparkPlan {
   private var parent: CodegenSupport = null
 
   /**
-    * Returns the RDD of InternalRow which generates the input rows.
+    * Returns all the RDDs of InternalRow which generates the input rows.
+    *
+    * Note: right now we support up to two RDDs.
     */
-  def upstream(): RDD[InternalRow]
+  def upstreams(): Seq[RDD[InternalRow]]
 
   /**
     * Returns Java source code to process the rows from upstream.
@@ -179,19 +182,23 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
 
   override def supportCodegen: Boolean = false
 
-  override def upstream(): RDD[InternalRow] = {
-    child.execute()
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.execute() :: Nil
   }
 
   override def doProduce(ctx: CodegenContext): String = {
+    val input = ctx.freshName("input")
+    // Right now, InputAdapter is only used when there is one upstream.
+    ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+
     val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
     val row = ctx.freshName("row")
     ctx.INPUT_ROW = row
     ctx.currentVars = null
     val columns = exprs.map(_.gen(ctx))
     s"""
-       | while (input.hasNext()) {
-       |   InternalRow $row = (InternalRow) input.next();
+       | while ($input.hasNext()) {
+       |   InternalRow $row = (InternalRow) $input.next();
        |   ${columns.map(_.code).mkString("\n").trim}
        |   ${consume(ctx, columns).trim}
        |   if (shouldStop()) {
@@ -215,7 +222,7 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
   *
   * -> execute()
   *     |
-  *  doExecute() --------->   upstream() -------> upstream() ------> execute()
+  *  doExecute() --------->   upstreams() -------> upstreams() ------> execute()
   *     |
   *      ----------------->   produce()
   *                             |
@@ -267,6 +274,9 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
 
         public GeneratedIterator(Object[] references) {
           this.references = references;
+        }
+
+        public void init(scala.collection.Iterator inputs[]) {
           ${ctx.initMutableStates()}
         }
 
@@ -283,19 +293,33 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
     // println(s"${CodeFormatter.format(cleanedSource)}")
     CodeGenerator.compile(cleanedSource)
 
-    plan.upstream().mapPartitions { iter =>
-
-      val clazz = CodeGenerator.compile(source)
-      val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
-      buffer.setInput(iter)
-      new Iterator[InternalRow] {
-        override def hasNext: Boolean = buffer.hasNext
-        override def next: InternalRow = buffer.next()
+    val rdds = plan.upstreams()
+    assert(rdds.size <= 2, "Up to two upstream RDDs can be supported")
+    if (rdds.length == 1) {
+      rdds.head.mapPartitions { iter =>
+        val clazz = CodeGenerator.compile(cleanedSource)
+        val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
+        buffer.init(Array(iter))
+        new Iterator[InternalRow] {
+          override def hasNext: Boolean = buffer.hasNext
+          override def next: InternalRow = buffer.next()
+        }
+      }
+    } else {
+      // Right now, we support up to two upstreams.
+      rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
+        val clazz = CodeGenerator.compile(cleanedSource)
+        val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
+        buffer.init(Array(leftIter, rightIter))
+        new Iterator[InternalRow] {
+          override def hasNext: Boolean = buffer.hasNext
+          override def next: InternalRow = buffer.next()
+        }
       }
     }
   }
 
-  override def upstream(): RDD[InternalRow] = {
+  override def upstreams(): Seq[RDD[InternalRow]] = {
     throw new UnsupportedOperationException
   }
 
@@ -312,7 +336,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
     if (row != null) {
       // There is an UnsafeRow already
       s"""
-         | currentRows.add($row.copy());
+         |append($row.copy());
        """.stripMargin
     } else {
       assert(input != null)
@@ -324,13 +348,13 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
         ctx.currentVars = input
         val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
         s"""
-           | ${code.code.trim}
-           | currentRows.add(${code.value}.copy());
+           |${code.code.trim}
+           |append(${code.value}.copy());
          """.stripMargin
       } else {
         // There is no columns
         s"""
-           | currentRows.add(unsafeRow);
+           |append(unsafeRow);
          """.stripMargin
       }
     }
@@ -402,6 +426,9 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
               b.copy(left = apply(left))
             case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) =>
               b.copy(right = apply(right))
+            case j @ SortMergeJoin(_, _, _, left, right) =>
+              // The children of SortMergeJoin should do codegen separately.
+              j.copy(left = apply(left), right = apply(right))
             case p if !supportCodegen(p) =>
               val input = apply(p)  // collapse them recursively
               inputs += input

http://git-wip-us.apache.org/repos/asf/spark/blob/9cdd867d/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 852203f..a467229 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -121,8 +121,8 @@ case class TungstenAggregate(
     !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
   }
 
-  override def upstream(): RDD[InternalRow] = {
-    child.asInstanceOf[CodegenSupport].upstream()
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
   }
 
   protected override def doProduce(ctx: CodegenContext): String = {

http://git-wip-us.apache.org/repos/asf/spark/blob/9cdd867d/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 55bddd1..b2f443c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -31,8 +31,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
 
   override def output: Seq[Attribute] = projectList.map(_.toAttribute)
 
-  override def upstream(): RDD[InternalRow] = {
-    child.asInstanceOf[CodegenSupport].upstream()
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
   }
 
   protected override def doProduce(ctx: CodegenContext): String = {
@@ -69,8 +69,8 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
   private[sql] override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
 
-  override def upstream(): RDD[InternalRow] = {
-    child.asInstanceOf[CodegenSupport].upstream()
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
   }
 
   protected override def doProduce(ctx: CodegenContext): String = {
@@ -156,8 +156,9 @@ case class Range(
   private[sql] override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
 
-  override def upstream(): RDD[InternalRow] = {
-    sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i))
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    sqlContext.sparkContext.parallelize(0 until numSlices, numSlices)
+      .map(i => InternalRow(i)) :: Nil
   }
 
   protected override def doProduce(ctx: CodegenContext): String = {
@@ -213,12 +214,15 @@ case class Range(
         | }
        """.stripMargin)
 
+    val input = ctx.freshName("input")
+    // Right now, Range is only used when there is one upstream.
+    ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
     s"""
       | // initialize Range
       | if (!$initTerm) {
       |   $initTerm = true;
-      |   if (input.hasNext()) {
-      |     initRange(((InternalRow) input.next()).getInt(0));
+      |   if ($input.hasNext()) {
+      |     initRange(((InternalRow) $input.next()).getInt(0));
       |   } else {
       |     return;
       |   }

http://git-wip-us.apache.org/repos/asf/spark/blob/9cdd867d/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index ddc0882..6699dba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -99,8 +99,8 @@ case class BroadcastHashJoin(
     }
   }
 
-  override def upstream(): RDD[InternalRow] = {
-    streamedPlan.asInstanceOf[CodegenSupport].upstream()
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    streamedPlan.asInstanceOf[CodegenSupport].upstreams()
   }
 
   override def doProduce(ctx: CodegenContext): String = {

http://git-wip-us.apache.org/repos/asf/spark/blob/9cdd867d/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
index e417079..fabd2fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
 
-
 /**
   * An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD,
   * will be much faster than building the right partition for every row in left RDD, it also

http://git-wip-us.apache.org/repos/asf/spark/blob/9cdd867d/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index cd8a567..7ec4027 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -22,9 +22,10 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan}
-import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, RowIterator, SparkPlan}
+import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
  * Performs an sort merge join of two child relations.
@@ -34,7 +35,7 @@ case class SortMergeJoin(
     rightKeys: Seq[Expression],
     condition: Option[Expression],
     left: SparkPlan,
-    right: SparkPlan) extends BinaryNode {
+    right: SparkPlan) extends BinaryNode with CodegenSupport {
 
   override private[sql] lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
@@ -125,6 +126,246 @@ case class SortMergeJoin(
       }.toScala
     }
   }
+
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    left.execute() :: right.execute() :: Nil
+  }
+
+  private def createJoinKey(
+      ctx: CodegenContext,
+      row: String,
+      keys: Seq[Expression],
+      input: Seq[Attribute]): Seq[ExprCode] = {
+    ctx.INPUT_ROW = row
+    keys.map(BindReferences.bindReference(_, input).gen(ctx))
+  }
+
+  private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = {
+    vars.zipWithIndex.map { case (ev, i) =>
+      val value = ctx.freshName("value")
+      ctx.addMutableState(ctx.javaType(leftKeys(i).dataType), value, "")
+      val code =
+        s"""
+           |$value = ${ev.value};
+         """.stripMargin
+      ExprCode(code, "false", value)
+    }
+  }
+
+  private def genComparision(ctx: CodegenContext, a: Seq[ExprCode], b: Seq[ExprCode]): String = {
+    val comparisons = a.zip(b).zipWithIndex.map { case ((l, r), i) =>
+      s"""
+         |if (comp == 0) {
+         |  comp = ${ctx.genComp(leftKeys(i).dataType, l.value, r.value)};
+         |}
+       """.stripMargin.trim
+    }
+    s"""
+       |comp = 0;
+       |${comparisons.mkString("\n")}
+     """.stripMargin
+  }
+
+  /**
+    * Generate a function to scan both left and right to find a match, returns the term for
+    * matched one row from left side and buffered rows from right side.
+    */
+  private def genScanner(ctx: CodegenContext): (String, String) = {
+    // Create class member for next row from both sides.
+    val leftRow = ctx.freshName("leftRow")
+    ctx.addMutableState("InternalRow", leftRow, "")
+    val rightRow = ctx.freshName("rightRow")
+    ctx.addMutableState("InternalRow", rightRow, s"$rightRow = null;")
+
+    // Create variables for join keys from both sides.
+    val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output)
+    val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ")
+    val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output)
+    val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ")
+    // Copy the right key as class members so they could be used in next function call.
+    val rightKeyVars = copyKeys(ctx, rightKeyTmpVars)
+
+    // A list to hold all matched rows from right side.
+    val matches = ctx.freshName("matches")
+    val clsName = classOf[java.util.ArrayList[InternalRow]].getName
+    ctx.addMutableState(clsName, matches, s"$matches = new $clsName();")
+    // Copy the left keys as class members so they could be used in next function call.
+    val matchedKeyVars = copyKeys(ctx, leftKeyVars)
+
+    ctx.addNewFunction("findNextInnerJoinRows",
+      s"""
+         |private boolean findNextInnerJoinRows(
+         |    scala.collection.Iterator leftIter,
+         |    scala.collection.Iterator rightIter) {
+         |  $leftRow = null;
+         |  int comp = 0;
+         |  while ($leftRow == null) {
+         |    if (!leftIter.hasNext()) return false;
+         |    $leftRow = (InternalRow) leftIter.next();
+         |    ${leftKeyVars.map(_.code).mkString("\n")}
+         |    if ($leftAnyNull) {
+         |      $leftRow = null;
+         |      continue;
+         |    }
+         |    if (!$matches.isEmpty()) {
+         |      ${genComparision(ctx, leftKeyVars, matchedKeyVars)}
+         |      if (comp == 0) {
+         |        return true;
+         |      }
+         |      $matches.clear();
+         |    }
+         |
+         |    do {
+         |      if ($rightRow == null) {
+         |        if (!rightIter.hasNext()) {
+         |          ${matchedKeyVars.map(_.code).mkString("\n")}
+         |          return !$matches.isEmpty();
+         |        }
+         |        $rightRow = (InternalRow) rightIter.next();
+         |        ${rightKeyTmpVars.map(_.code).mkString("\n")}
+         |        if ($rightAnyNull) {
+         |          $rightRow = null;
+         |          continue;
+         |        }
+         |        ${rightKeyVars.map(_.code).mkString("\n")}
+         |      }
+         |      ${genComparision(ctx, leftKeyVars, rightKeyVars)}
+         |      if (comp > 0) {
+         |        $rightRow = null;
+         |      } else if (comp < 0) {
+         |        if (!$matches.isEmpty()) {
+         |          ${matchedKeyVars.map(_.code).mkString("\n")}
+         |          return true;
+         |        }
+         |        $leftRow = null;
+         |      } else {
+         |        $matches.add($rightRow.copy());
+         |        $rightRow = null;;
+         |      }
+         |    } while ($leftRow != null);
+         |  }
+         |  return false; // unreachable
+         |}
+       """.stripMargin)
+
+    (leftRow, matches)
+  }
+
+  /**
+    * Creates variables for left part of result row.
+    *
+    * In order to defer the access after condition and also only access once in the loop,
+    * the variables should be declared separately from accessing the columns, we can't use the
+    * codegen of BoundReference here.
+    */
+  private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = {
+    ctx.INPUT_ROW = leftRow
+    left.output.zipWithIndex.map { case (a, i) =>
+      val value = ctx.freshName("value")
+      val valueCode = ctx.getValue(leftRow, a.dataType, i.toString)
+      // declare it as class member, so we can access the column before or in the loop.
+      ctx.addMutableState(ctx.javaType(a.dataType), value, "")
+      if (a.nullable) {
+        val isNull = ctx.freshName("isNull")
+        ctx.addMutableState("boolean", isNull, "")
+        val code =
+          s"""
+             |$isNull = $leftRow.isNullAt($i);
+             |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode);
+           """.stripMargin
+        ExprCode(code, isNull, value)
+      } else {
+        ExprCode(s"$value = $valueCode;", "false", value)
+      }
+    }
+  }
+
+  /**
+    * Creates the variables for right part of result row, using BoundReference, since the right
+    * part are accessed inside the loop.
+    */
+  private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = {
+    ctx.INPUT_ROW = rightRow
+    right.output.zipWithIndex.map { case (a, i) =>
+      BoundReference(i, a.dataType, a.nullable).gen(ctx)
+    }
+  }
+
+  /**
+    * Splits variables based on whether it's used by condition or not, returns the code to create
+    * these variables before the condition and after the condition.
+    *
+    * Only a few columns are used by condition, then we can skip the accessing of those columns
+    * that are not used by condition also filtered out by condition.
+    */
+  private def splitVarsByCondition(
+      attributes: Seq[Attribute],
+      variables: Seq[ExprCode]): (String, String) = {
+    if (condition.isDefined) {
+      val condRefs = condition.get.references
+      val (used, notUsed) = attributes.zip(variables).partition{ case (a, ev) =>
+        condRefs.contains(a)
+      }
+      val beforeCond = used.map(_._2.code).mkString("\n")
+      val afterCond = notUsed.map(_._2.code).mkString("\n")
+      (beforeCond, afterCond)
+    } else {
+      (variables.map(_.code).mkString("\n"), "")
+    }
+  }
+
+  override def doProduce(ctx: CodegenContext): String = {
+    val leftInput = ctx.freshName("leftInput")
+    ctx.addMutableState("scala.collection.Iterator", leftInput, s"$leftInput = inputs[0];")
+    val rightInput = ctx.freshName("rightInput")
+    ctx.addMutableState("scala.collection.Iterator", rightInput, s"$rightInput = inputs[1];")
+
+    val (leftRow, matches) = genScanner(ctx)
+
+    // Create variables for row from both sides.
+    val leftVars = createLeftVars(ctx, leftRow)
+    val rightRow = ctx.freshName("rightRow")
+    val rightVars = createRightVar(ctx, rightRow)
+    val resultVars = leftVars ++ rightVars
+
+    // Check condition
+    ctx.currentVars = resultVars
+    val cond = if (condition.isDefined) {
+      BindReferences.bindReference(condition.get, output).gen(ctx)
+    } else {
+      ExprCode("", "false", "true")
+    }
+    // Split the code of creating variables based on whether it's used by condition or not.
+    val loaded = ctx.freshName("loaded")
+    val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
+    val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
+
+
+    val size = ctx.freshName("size")
+    val i = ctx.freshName("i")
+    val numOutput = metricTerm(ctx, "numOutputRows")
+    s"""
+       |while (findNextInnerJoinRows($leftInput, $rightInput)) {
+       |  int $size = $matches.size();
+       |  boolean $loaded = false;
+       |  $leftBefore
+       |  for (int $i = 0; $i < $size; $i ++) {
+       |    InternalRow $rightRow = (InternalRow) $matches.get($i);
+       |    $rightBefore
+       |    ${cond.code}
+       |    if (${cond.isNull} || !${cond.value}) continue;
+       |    if (!$loaded) {
+       |      $loaded = true;
+       |      $leftAfter
+       |    }
+       |    $rightAfter
+       |    $numOutput.add(1);
+       |    ${consume(ctx, resultVars)}
+       |  }
+       |  if (shouldStop()) return;
+       |}
+     """.stripMargin
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/9cdd867d/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index bcac660..6d6cc01 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -38,6 +38,7 @@ import org.apache.spark.util.Benchmark
 class BenchmarkWholeStageCodegen extends SparkFunSuite {
   lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark")
     .set("spark.sql.shuffle.partitions", "1")
+    .set("spark.sql.autoBroadcastJoinThreshold", "0")
   lazy val sc = SparkContext.getOrCreate(conf)
   lazy val sqlContext = SQLContext.getOrCreate(sc)
 
@@ -187,6 +188,39 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
       */
   }
 
+  ignore("sort merge join") {
+    val N = 2 << 20
+    runBenchmark("merge join", N) {
+      val df1 = sqlContext.range(N).selectExpr(s"id * 2 as k1")
+      val df2 = sqlContext.range(N).selectExpr(s"id * 3 as k2")
+      df1.join(df2, col("k1") === col("k2")).count()
+    }
+
+    /**
+    Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+    merge join:                         Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    merge join codegen=false                 1588 / 1880          1.3         757.1       1.0X
+    merge join codegen=true                  1477 / 1531          1.4         704.2       1.1X
+      */
+
+    runBenchmark("sort merge join", N) {
+      val df1 = sqlContext.range(N)
+        .selectExpr(s"(id * 15485863) % ${N*10} as k1")
+      val df2 = sqlContext.range(N)
+        .selectExpr(s"(id * 15485867) % ${N*10} as k2")
+      df1.join(df2, col("k1") === col("k2")).count()
+    }
+
+    /**
+    Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+    sort merge join:                    Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    sort merge join codegen=false            3626 / 3667          0.6        1728.9       1.0X
+    sort merge join codegen=true             3405 / 3438          0.6        1623.8       1.1X
+      */
+  }
+
   ignore("rube") {
     val N = 5 << 20
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9cdd867d/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index a05a57c..0ef42f4 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -240,7 +240,8 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
       withBucket(df1.write.format("parquet"), bucketSpecLeft).saveAsTable("bucketed_table1")
       withBucket(df2.write.format("parquet"), bucketSpecRight).saveAsTable("bucketed_table2")
 
-      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
+      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0",
+        SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
         val t1 = hiveContext.table("bucketed_table1")
         val t2 = hiveContext.table("bucketed_table2")
         val joined = t1.join(t2, joinCondition(t1, t2, joinColumns))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org