Posted to commits@spark.apache.org by da...@apache.org on 2016/02/03 19:39:01 UTC

spark git commit: [SPARK-12798] [SQL] generated BroadcastHashJoin

Repository: spark
Updated Branches:
  refs/heads/master e9eb248ed -> c4feec26e


[SPARK-12798] [SQL] generated BroadcastHashJoin

A row from the stream side can match multiple rows on the build side, and the loop over those matched rows must not be interrupted while emitting them, so we buffer the output rows in a linked list and check the termination condition in the producer loop (for example, Range or Aggregate).
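
To make that contract concrete, here is a minimal, Spark-independent sketch of the buffering pattern (all names are illustrative, not Spark's): one producer step may yield zero or more rows, so the consumer drains a queue and only asks for more input once the queue is empty.

  import scala.collection.mutable

  // Minimal sketch of the buffering contract; names are illustrative, not Spark's.
  // One producer step may yield zero or more rows, so hasNext() refills a queue
  // and next() drains it; the producer is never cut off in the middle of a match.
  class BufferingIterator[T](steps: Iterator[Seq[T]]) {
    private val buffer = mutable.Queue.empty[T]

    def hasNext: Boolean = {
      // keep pulling steps until some output is buffered or the input is exhausted
      while (buffer.isEmpty && steps.hasNext) buffer ++= steps.next()
      buffer.nonEmpty
    }

    def next(): T = buffer.dequeue()
  }

For a join, one step corresponds to one streamed row yielding all of its build-side matches at once.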

Author: Davies Liu <da...@databricks.com>

Closes #10989 from davies/gen_join.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c4feec26
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c4feec26
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c4feec26

Branch: refs/heads/master
Commit: c4feec26eb677bfd3bfac38e5e28eae05279956e
Parents: e9eb248
Author: Davies Liu <da...@databricks.com>
Authored: Wed Feb 3 10:38:53 2016 -0800
Committer: Davies Liu <da...@gmail.com>
Committed: Wed Feb 3 10:38:53 2016 -0800

----------------------------------------------------------------------
 .../sql/execution/BufferedRowIterator.java      | 30 +++++--
 .../spark/sql/execution/WholeStageCodegen.scala | 18 ++--
 .../execution/aggregate/TungstenAggregate.scala |  4 +-
 .../spark/sql/execution/basicOperators.scala    |  2 +
 .../sql/execution/joins/BroadcastHashJoin.scala | 92 +++++++++++++++++++-
 .../execution/BenchmarkWholeStageCodegen.scala  | 28 +++++-
 .../sql/execution/WholeStageCodegenSuite.scala  | 15 +++-
 7 files changed, 169 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c4feec26/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
index 6acf70d..ea20115 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -18,9 +18,11 @@
 package org.apache.spark.sql.execution;
 
 import java.io.IOException;
+import java.util.LinkedList;
 
 import scala.collection.Iterator;
 
+import org.apache.spark.TaskContext;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 
@@ -31,22 +33,20 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
  * TODO: replace it with a batched columnar format.
  */
 public class BufferedRowIterator {
-  protected InternalRow currentRow;
+  protected LinkedList<InternalRow> currentRows = new LinkedList<>();
   protected Iterator<InternalRow> input;
   // used when there is no column in output
   protected UnsafeRow unsafeRow = new UnsafeRow(0);
 
   public boolean hasNext() throws IOException {
-    if (currentRow == null) {
+    if (currentRows.isEmpty()) {
       processNext();
     }
-    return currentRow != null;
+    return !currentRows.isEmpty();
   }
 
   public InternalRow next() {
-    InternalRow r = currentRow;
-    currentRow = null;
-    return r;
+    return currentRows.remove();
   }
 
   public void setInput(Iterator<InternalRow> iter) {
@@ -54,13 +54,29 @@ public class BufferedRowIterator {
   }
 
   /**
+   * Returns whether `processNext()` should stop processing the next row from `input`.
+   *
+   * If it returns true, the caller should exit the loop (return from processNext()).
+   */
+  protected boolean shouldStop() {
+    return !currentRows.isEmpty();
+  }
+
+  /**
+   * Increase the peak execution memory for current task.
+   */
+  protected void incPeakExecutionMemory(long size) {
+    TaskContext.get().taskMetrics().incPeakExecutionMemory(size);
+  }
+
+  /**
    * Processes the input until there is a row as output (in currentRows).
    *
    * After it returns, if currentRows is still empty, no more rows are left.
    */
   protected void processNext() throws IOException {
     if (input.hasNext()) {
-      currentRow = input.next();
+      currentRows.add(input.next());
     }
   }
 }
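
The contract implemented above: hasNext() invokes processNext() only while currentRows is empty, and generated operators consult shouldStop() to yield as soon as output is buffered. A hedged sketch of the consuming side (the real wiring lives in WholeStageCodegen.doExecute; `drain`, `generatedIter`, and `upstreamIter` are illustrative names, not Spark identifiers):

  import scala.collection.Iterator

  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.execution.BufferedRowIterator

  // Hedged sketch of how a compiled stage is drained; `generatedIter` is a
  // codegen-produced subclass of BufferedRowIterator, `upstreamIter` its input.
  def drain(generatedIter: BufferedRowIterator,
            upstreamIter: Iterator[InternalRow]): Unit = {
    generatedIter.setInput(upstreamIter)   // feed the stage its input rows
    while (generatedIter.hasNext) {        // refills currentRows via processNext()
      val row = generatedIter.next()       // pops one buffered row per call
      // ... hand `row` to the parent consumer ...
    }
  }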

http://git-wip-us.apache.org/repos/asf/spark/blob/c4feec26/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 1475496..131efea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight}
 import org.apache.spark.util.Utils
 
 /**
@@ -172,6 +173,9 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
        |   InternalRow $row = (InternalRow) input.next();
        |   ${columns.map(_.code).mkString("\n").trim}
        |   ${consume(ctx, columns).trim}
+       |   if (shouldStop()) {
+       |     return;
+       |   }
        | }
      """.stripMargin
   }
@@ -283,8 +287,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
     if (row != null) {
       // There is an UnsafeRow already
       s"""
-         | currentRow = $row;
-         | return;
+         | currentRows.add($row.copy());
        """.stripMargin
     } else {
       assert(input != null)
@@ -297,14 +300,12 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
         val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
         s"""
            | ${code.code.trim}
-           | currentRow = ${code.value};
-           | return;
+           | currentRows.add(${code.value}.copy());
          """.stripMargin
       } else {
         // There is no columns
         s"""
-           | currentRow = unsafeRow;
-           | return;
+           | currentRows.add(unsafeRow);
          """.stripMargin
       }
     }
@@ -371,6 +372,11 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
 
           var inputs = ArrayBuffer[SparkPlan]()
           val combined = plan.transform {
+            // The build side can't be compiled into the same codegen stage
+            case b @ BroadcastHashJoin(_, _, BuildLeft, _, left, right) =>
+              b.copy(left = apply(left))
+            case b @ BroadcastHashJoin(_, _, BuildRight, _, left, right) =>
+              b.copy(right = apply(right))
             case p if !supportCodegen(p) =>
               val input = apply(p)  // collapse them recursively
               inputs += input
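
The two new cases keep the build side out of the stage being collapsed: its rows are consumed once up front to build the broadcast hash relation, so only the streamed side can be fused with the join. Illustratively (a hedged sketch assuming a SQLContext named sqlContext, as in the tests below; the plan shape shown is approximate, not exact EXPLAIN output):

  import org.apache.spark.sql.functions.{broadcast, col}

  // Hedged illustration of the expected plan shape under this rule.
  val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k"))
  val df  = sqlContext.range(1 << 20).join(dim, col("id") === col("k"))
  df.explain()
  // Roughly:
  //   WholeStageCodegen
  //     BroadcastHashJoin [k] [id] BuildRight
  //       Range ...              <- streamed side, fused with the join
  //       WholeStageCodegen      <- build side, collapsed on its own
  //         Project ...
  //           Range ...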

http://git-wip-us.apache.org/repos/asf/spark/blob/c4feec26/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index d024477..9d9f14f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -471,6 +471,8 @@ case class TungstenAggregate(
        UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
        UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
        $outputCode
+
+       if (shouldStop()) return;
      }
 
      $iterTerm.close();
@@ -480,7 +482,7 @@ case class TungstenAggregate(
      """
   }
 
-  private def doConsumeWithKeys( ctx: CodegenContext, input: Seq[ExprCode]): String = {
+  private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
 
     // create grouping key
     ctx.currentVars = input

http://git-wip-us.apache.org/repos/asf/spark/blob/c4feec26/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index ae44221..6e51c4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -237,6 +237,8 @@ case class Range(
       |    $overflow = true;
       |  }
       |  ${consume(ctx, Seq(ev))}
+      |
+      |  if (shouldStop()) return;
       | }
      """.stripMargin
   }
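
Because the loop counter lives in a mutable field of the generated class, the early return on shouldStop() suspends rather than aborts the scan: the next processNext() call re-enters the loop where it left off. A simplified Scala rendering of that shape (the real emitted code is Java with generated fresh names; `consume` and `shouldStop` stand in for generated logic):

  // Simplified Scala rendering of the generated Range loop; the actual
  // emitted code is Java. Bodies of the stand-in methods are elided.
  class RangeStageSketch(start: Long, end: Long) {
    private var number = start                  // loop state persists across calls
    private def consume(v: Long): Unit = ???    // downstream operator code
    private def shouldStop(): Boolean = ???     // true once output rows are buffered

    def processNext(): Unit = {
      while (number < end) {
        val value = number
        number += 1
        consume(value)            // may append one or more rows to the buffer
        if (shouldStop()) return  // suspend; resume from `number` next call
      }
    }
  }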

http://git-wip-us.apache.org/repos/asf/spark/blob/c4feec26/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 0464071..8b275e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -20,14 +20,17 @@ package org.apache.spark.sql.execution.joins
 import scala.concurrent._
 import scala.concurrent.duration._
 
-import org.apache.spark.{InternalAccumulator, TaskContext}
+import org.apache.spark.TaskContext
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{BindReferences, BoundReference, Expression, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.util.ThreadUtils
+import org.apache.spark.util.collection.CompactBuffer
 
 /**
  * Performs an inner hash join of two child relations.  When the output RDD of this operator is
@@ -42,7 +45,7 @@ case class BroadcastHashJoin(
     condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan)
-  extends BinaryNode with HashJoin {
+  extends BinaryNode with HashJoin with CodegenSupport {
 
   override private[sql] lazy val metrics = Map(
     "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
@@ -117,6 +120,87 @@ case class BroadcastHashJoin(
       hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows)
     }
   }
+
+  // the generated variable name that holds the hash relation
+  private var relationTerm: String = _
+
+  override def upstream(): RDD[InternalRow] = {
+    streamedPlan.asInstanceOf[CodegenSupport].upstream()
+  }
+
+  override def doProduce(ctx: CodegenContext): String = {
+    // create a name for HashRelation
+    val broadcastRelation = Await.result(broadcastFuture, timeout)
+    val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
+    relationTerm = ctx.freshName("relation")
+    // TODO: create specialized HashRelation for single join key
+    val clsName = classOf[UnsafeHashedRelation].getName
+    ctx.addMutableState(clsName, relationTerm,
+      s"""
+         | $relationTerm = ($clsName) $broadcast.value();
+         | incPeakExecutionMemory($relationTerm.getUnsafeSize());
+       """.stripMargin)
+
+    s"""
+       | ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)}
+     """.stripMargin
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    // generate the key as UnsafeRow
+    ctx.currentVars = input
+    val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
+    val keyVal = GenerateUnsafeProjection.createCode(ctx, keyExpr)
+    val keyTerm = keyVal.value
+    val anyNull = if (keyExpr.exists(_.nullable)) s"$keyTerm.anyNull()" else "false"
+
+    // find the matches from HashedRelation
+    val matches = ctx.freshName("matches")
+    val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
+    val i = ctx.freshName("i")
+    val size = ctx.freshName("size")
+    val row = ctx.freshName("row")
+
+    // create variables for output
+    ctx.currentVars = null
+    ctx.INPUT_ROW = row
+    val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) =>
+      BoundReference(i, a.dataType, a.nullable).gen(ctx)
+    }
+    val resultVars = buildSide match {
+      case BuildLeft => buildColumns ++ input
+      case BuildRight => input ++ buildColumns
+    }
+
+    val outputCode = if (condition.isDefined) {
+      // filter the output via condition
+      ctx.currentVars = resultVars
+      val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
+      s"""
+         | ${ev.code}
+         | if (!${ev.isNull} && ${ev.value}) {
+         |   ${consume(ctx, resultVars)}
+         | }
+       """.stripMargin
+    } else {
+      consume(ctx, resultVars)
+    }
+
+    s"""
+       | // generate join key
+       | ${keyVal.code}
+       | // find matches from HashRelation
+       | $bufferType $matches = $anyNull ? null : ($bufferType) $relationTerm.get($keyTerm);
+       | if ($matches != null) {
+       |   int $size = $matches.size();
+       |   for (int $i = 0; $i < $size; $i++) {
+       |     UnsafeRow $row = (UnsafeRow) $matches.apply($i);
+       |     ${buildColumns.map(_.code).mkString("\n")}
+       |     $outputCode
+       |   }
+       | }
+     """.stripMargin
+  }
 }
 
 object BroadcastHashJoin {
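
Per streamed row, the doConsume template above expands to an inner loop of roughly this shape, rendered here as Scala for readability (the emitted code is Java with generated fresh names; `relation` stands in for the UnsafeHashedRelation lookup and `emit` for the generated consume code):

  import org.apache.spark.sql.catalyst.expressions.UnsafeRow

  // Hedged Scala rendering of the expanded join loop; `relation` and `emit`
  // are illustrative stand-ins, not Spark identifiers.
  class JoinLoopSketch(relation: UnsafeRow => Seq[UnsafeRow],
                       emit: (UnsafeRow, UnsafeRow) => Unit) {
    def onStreamedRow(streamedRow: UnsafeRow, keyRow: UnsafeRow): Unit = {
      // find matches from the hash relation; a key with any null never matches
      val matches = if (keyRow.anyNull) null else relation(keyRow)
      if (matches != null) {
        var i = 0
        while (i < matches.size) {
          emit(streamedRow, matches(i))   // one joined output row per match
          i += 1
        }
      }
    }
  }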

http://git-wip-us.apache.org/repos/asf/spark/blob/c4feec26/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index ec2b9ab..15ba773 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -21,6 +21,7 @@ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
 import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.functions._
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.hash.Murmur3_x86_32
 import org.apache.spark.unsafe.map.BytesToBytesMap
@@ -130,6 +131,30 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     benchmark.run()
   }
 
+  def testBroadcastHashJoin(values: Int): Unit = {
+    val benchmark = new Benchmark("BroadcastHashJoin", values)
+
+    val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v"))
+
+    benchmark.addCase("BroadcastHashJoin w/o codegen") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
+      sqlContext.range(values).join(dim, (col("id") % 60000) === col("k")).count()
+    }
+    benchmark.addCase(s"BroadcastHashJoin w codegen") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
+      sqlContext.range(values).join(dim, (col("id") % 60000) === col("k")).count()
+    }
+
+    /*
+      Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+      BroadcastHashJoin:                 Avg Time(ms)    Avg Rate(M/s)  Relative Rate
+      -------------------------------------------------------------------------------
+      BroadcastHashJoin w/o codegen           3053.41             3.43         1.00 X
+      BroadcastHashJoin w codegen             1028.40            10.20         2.97 X
+    */
+    benchmark.run()
+  }
+
   def testBytesToBytesMap(values: Int): Unit = {
     val benchmark = new Benchmark("BytesToBytesMap", values)
 
@@ -201,6 +226,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     // testWholeStage(200 << 20)
     // testStatFunctions(20 << 20)
     // testAggregateWithKey(20 << 20)
-    // testBytesToBytesMap(1024 * 1024 * 50)
+    // testBytesToBytesMap(50 << 20)
+    // testBroadcastHashJoin(10 << 20)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c4feec26/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index c251650..9350205 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -20,8 +20,10 @@ package org.apache.spark.sql.execution
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
-import org.apache.spark.sql.functions.{avg, col, max}
+import org.apache.spark.sql.execution.joins.BroadcastHashJoin
+import org.apache.spark.sql.functions.{avg, broadcast, col, max}
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
 
 class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
 
@@ -56,4 +58,15 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
         p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
     assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))
   }
+
+  test("BroadcastHashJoin should be included in WholeStageCodegen") {
+    val rdd = sqlContext.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2")))
+    val schema = new StructType().add("k", IntegerType).add("v", StringType)
+    val smallDF = sqlContext.createDataFrame(rdd, schema)
+    val df = sqlContext.range(10).join(broadcast(smallDF), col("k") === col("id"))
+    assert(df.queryExecution.executedPlan.find(p =>
+      p.isInstanceOf[WholeStageCodegen] &&
+        p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined)
+    assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
+  }
 }

