You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2016/01/30 05:16:15 UTC

spark git commit: [SPARK-12914] [SQL] generate aggregation with grouping keys

Repository: spark
Updated Branches:
  refs/heads/master 12252d1da -> e6a02c66d


[SPARK-12914] [SQL] generate aggregation with grouping keys

This PR adds support for grouping keys to the generated (whole-stage codegen) TungstenAggregate.

Spilling and performance improvements for BytesToBytesMap will be done in a follow-up PR.

Author: Davies Liu <da...@databricks.com>

Closes #10855 from davies/gen_keys.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e6a02c66
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e6a02c66
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e6a02c66

Branch: refs/heads/master
Commit: e6a02c66d53f59ba2d5c1548494ae80a385f9f5c
Parents: 12252d1
Author: Davies Liu <da...@databricks.com>
Authored: Fri Jan 29 20:16:11 2016 -0800
Committer: Davies Liu <da...@gmail.com>
Committed: Fri Jan 29 20:16:11 2016 -0800

----------------------------------------------------------------------
 .../expressions/codegen/CodeGenerator.scala     |  47 ++++
 .../codegen/GenerateMutableProjection.scala     |  27 +--
 .../sql/execution/BufferedRowIterator.java      |   6 +-
 .../execution/aggregate/TungstenAggregate.scala | 238 +++++++++++++++++--
 .../execution/BenchmarkWholeStageCodegen.scala  | 119 +++++++++-
 .../sql/execution/WholeStageCodegenSuite.scala  |   9 +
 6 files changed, 393 insertions(+), 53 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e6a02c66/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index e6704cf..21f9198 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -56,6 +56,20 @@ class CodegenContext {
   val references: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]()
 
   /**
+    * Adds `obj` to `references` and declares a class member of type `className` (default:
+    * the runtime class of `obj`) that is initialized from it.
+    * Returns the member's name, a fresh name derived from `name`.
+    */
+  def addReferenceObj(name: String, obj: Any, className: String = null): String = {
+    val term = freshName(name)
+    val idx = references.length  // position `obj` is appended at
+    references += obj
+    val clsName = Option(className).getOrElse(obj.getClass.getName)
+    addMutableState(clsName, term, s"this.$term = ($clsName) references[$idx];")
+    term
+  }
+
+  /**
     * Holding a list of generated columns as input of current operator, will be used by
     * BoundReference to generate code.
     */
@@ -199,6 +213,39 @@ class CodegenContext {
   }
 
   /**
+    * Returns Java code that writes `ev` to column `ordinal` of `row`, null-aware if `nullable`.
+    */
+  def updateColumn(
+      row: String,
+      dataType: DataType,
+      ordinal: Int,
+      ev: ExprCode,
+      nullable: Boolean): String = {
+    if (nullable) {
+      // Decimals write null via setColumn instead of setNullAt, to keep the offset
+      if (dataType.isInstanceOf[DecimalType]) {
+        s"""
+           if (!${ev.isNull}) {
+             ${setColumn(row, dataType, ordinal, ev.value)};
+           } else {
+             ${setColumn(row, dataType, ordinal, "null")};
+           }
+         """
+      } else {
+        s"""
+           if (!${ev.isNull}) {
+             ${setColumn(row, dataType, ordinal, ev.value)};
+           } else {
+             $row.setNullAt($ordinal);
+           }
+         """
+      }
+    } else {
+      s"""${setColumn(row, dataType, ordinal, ev.value)};"""  // ev.isNull is not consulted
+    }
+  }
+
+  /**
    * Returns the name used in accessor and setter for a Java primitive type.
    */
   def primitiveTypeName(jt: String): String = jt match {

http://git-wip-us.apache.org/repos/asf/spark/blob/e6a02c66/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index ec31db1..5b4dc8d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -88,31 +88,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
 
     val updates = validExpr.zip(index).map {
       case (e, i) =>
-        if (e.nullable) {
-          if (e.dataType.isInstanceOf[DecimalType]) {
-            // Can't call setNullAt on DecimalType, because we need to keep the offset
-            s"""
-              if (this.isNull_$i) {
-                ${ctx.setColumn("mutableRow", e.dataType, i, "null")};
-              } else {
-                ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
-              }
-            """
-          } else {
-            s"""
-              if (this.isNull_$i) {
-                mutableRow.setNullAt($i);
-              } else {
-                ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
-              }
-            """
-          }
-        } else {
-          s"""
-            ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
-          """
-        }
-
+        val ev = ExprCode("", s"this.isNull_$i", s"this.value_$i")
+        ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
     }
 
     val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes)

http://git-wip-us.apache.org/repos/asf/spark/blob/e6a02c66/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
index b1bbb1d..6acf70d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution;
 
+import java.io.IOException;
+
 import scala.collection.Iterator;
 
 import org.apache.spark.sql.catalyst.InternalRow;
@@ -34,7 +36,7 @@ public class BufferedRowIterator {
   // used when there is no column in output
   protected UnsafeRow unsafeRow = new UnsafeRow(0);
 
-  public boolean hasNext() {
+  public boolean hasNext() throws IOException {
     if (currentRow == null) {
       processNext();
     }
@@ -56,7 +58,7 @@ public class BufferedRowIterator {
    *
    * After it's called, if currentRow is still null, it means no more rows left.
    */
-  protected void processNext() {
+  protected void processNext() throws IOException {
     if (input.hasNext()) {
       currentRow = input.next();
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/e6a02c66/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index ff2f38b..57db726 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -17,16 +17,18 @@
 
 package org.apache.spark.sql.execution.aggregate
 
+import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DecimalType, StructType}
+import org.apache.spark.unsafe.KVIterator
 
 case class TungstenAggregate(
     requiredChildDistributionExpressions: Option[Seq[Expression]],
@@ -114,22 +116,38 @@ case class TungstenAggregate(
     }
   }
 
+  // the distinct modes (Partial, Final, ...) of all aggregate expressions
+  private val modes = aggregateExpressions.map(_.mode).distinct
+
   override def supportCodegen: Boolean = {
-    groupingExpressions.isEmpty &&
-      // ImperativeAggregate is not supported right now
-      !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
+    // Codegen is supported unless some aggregate function is an ImperativeAggregate
+    !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
   }
 
-  // The variables used as aggregation buffer
-  private var bufVars: Seq[ExprCode] = _
-
-  private val modes = aggregateExpressions.map(_.mode).distinct
-
   override def upstream(): RDD[InternalRow] = {
     child.asInstanceOf[CodegenSupport].upstream()
   }
 
   protected override def doProduce(ctx: CodegenContext): String = {
+    if (groupingExpressions.isEmpty) {  // no grouping keys: a single aggregation buffer
+      doProduceWithoutKeys(ctx)
+    } else {  // grouping keys: one buffer per key, kept in a hash map
+      doProduceWithKeys(ctx)
+    }
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    if (groupingExpressions.isEmpty) {  // update the single aggregation buffer
+      doConsumeWithoutKeys(ctx, input)
+    } else {  // update the buffer looked up by this row's grouping key
+      doConsumeWithKeys(ctx, input)
+    }
+  }
+
+  // Generated variables for the aggregation buffer; only used when there are no grouping keys
+  private var bufVars: Seq[ExprCode] = _
+
+  private def doProduceWithoutKeys(ctx: CodegenContext): String = {
     val initAgg = ctx.freshName("initAgg")
     ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
 
@@ -176,10 +194,10 @@ case class TungstenAggregate(
       (resultVars, resultVars.map(_.code).mkString("\n"))
     }
 
-    val doAgg = ctx.freshName("doAgg")
+    val doAgg = ctx.freshName("doAggregateWithoutKey")
     ctx.addNewFunction(doAgg,
       s"""
-         | private void $doAgg() {
+         | private void $doAgg() throws java.io.IOException {
          |   // initialize aggregation buffer
          |   ${bufVars.map(_.code).mkString("\n")}
          |
@@ -200,7 +218,7 @@ case class TungstenAggregate(
      """.stripMargin
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+  private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     // only have DeclarativeAggregate
     val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
     val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
@@ -212,7 +230,6 @@ case class TungstenAggregate(
           e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
       }
     }
-
     ctx.currentVars = bufVars ++ input
     // TODO: support subexpression elimination
     val aggVals = updateExpr.map(BindReferences.bindReference(_, inputAttrs).gen(ctx))
@@ -232,6 +249,199 @@ case class TungstenAggregate(
      """.stripMargin
   }
 
+  private val groupingAttributes = groupingExpressions.map(_.toAttribute)
+  private val groupingKeySchema = StructType.fromAttributes(groupingAttributes)
+  // codegen only handles DeclarativeAggregate; ImperativeAggregate fails supportCodegen
+  private val declFunctions = aggregateExpressions.map(_.aggregateFunction)
+    .collect { case f: DeclarativeAggregate => f }
+  private val bufferAttributes = declFunctions.flatMap(_.aggBufferAttributes)
+  private val bufferSchema = StructType.fromAttributes(bufferAttributes)
+
+  // Name of the generated hash map member; assigned in doProduceWithKeys
+  private var hashMapTerm: String = _
+
+  /**
+    * Creates the key -> buffer hash map; called by the generated Java class, so must be public.
+    */
+  def createHashMap(): UnsafeFixedWidthAggregationMap = {
+    // create initialized aggregate buffer
+    val initExpr = declFunctions.flatMap(f => f.initialValues)
+    val initialBuffer = UnsafeProjection.create(initExpr)(EmptyRow)
+    val taskMemoryManager = TaskContext.get().taskMemoryManager()
+    // create hashMap
+    new UnsafeFixedWidthAggregationMap(
+      initialBuffer,
+      bufferSchema,
+      groupingKeySchema,
+      taskMemoryManager,
+      1024 * 16, // initial capacity
+      taskMemoryManager.pageSizeBytes,
+      false // disable tracking of performance metrics
+    )
+  }
+
+  /**
+    * Joins a grouping-key row with its buffer row; public because generated code calls it.
+    */
+  def createUnsafeJoiner(): UnsafeRowJoiner = {
+    GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
+  }
+
+
+  /**
+    * Adds the map's peak memory to the task's peak execution memory; called by generated code.
+    */
+  def updatePeakMemory(hashMap: UnsafeFixedWidthAggregationMap): Unit = {
+    val mapMemory = hashMap.getPeakMemoryUsedBytes
+    val metrics = TaskContext.get().taskMetrics()
+    metrics.incPeakExecutionMemory(mapMemory)
+  }
+
+  private def doProduceWithKeys(ctx: CodegenContext): String = {
+    val initAgg = ctx.freshName("initAgg")
+    ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
+
+    // create hashMap
+    val thisPlan = ctx.addReferenceObj("plan", this)
+    hashMapTerm = ctx.freshName("hashMap")
+    val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
+    ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();")
+
+    // Create a name for iterator from HashMap
+    val iterTerm = ctx.freshName("mapIter")
+    ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "")
+
+    // generate code for output
+    val keyTerm = ctx.freshName("aggKey")
+    val bufferTerm = ctx.freshName("aggBuffer")
+    val outputCode = if (modes.contains(Final) || modes.contains(Complete)) {
+      // generate output using resultExpressions
+      ctx.currentVars = null
+      ctx.INPUT_ROW = keyTerm
+      val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
+        BoundReference(i, e.dataType, e.nullable).gen(ctx)
+      }
+      ctx.INPUT_ROW = bufferTerm
+      val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) =>
+        BoundReference(i, e.dataType, e.nullable).gen(ctx)
+      }
+      // evaluate the aggregation result
+      ctx.currentVars = bufferVars
+      val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
+        BindReferences.bindReference(e, bufferAttributes).gen(ctx)
+      }
+      // generate the final result
+      ctx.currentVars = keyVars ++ aggResults
+      val inputAttrs = groupingAttributes ++ aggregateAttributes
+      val resultVars = resultExpressions.map { e =>
+        BindReferences.bindReference(e, inputAttrs).gen(ctx)
+      }
+      s"""
+       ${keyVars.map(_.code).mkString("\n")}
+       ${bufferVars.map(_.code).mkString("\n")}
+       ${aggResults.map(_.code).mkString("\n")}
+       ${resultVars.map(_.code).mkString("\n")}
+
+       ${consume(ctx, resultVars)}
+       """
+
+    } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
+      // This should be the last operator in a stage, we should output UnsafeRow directly
+      val joinerTerm = ctx.freshName("unsafeRowJoiner")
+      ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm,
+        s"$joinerTerm = $thisPlan.createUnsafeJoiner();")
+      val resultRow = ctx.freshName("resultRow")
+      s"""
+       UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm);
+       ${consume(ctx, null, resultRow)}
+       """
+
+    } else {
+      // generate result based on grouping key
+      ctx.INPUT_ROW = keyTerm
+      ctx.currentVars = null
+      val eval = resultExpressions.map { e =>
+        BindReferences.bindReference(e, groupingAttributes).gen(ctx)
+      }
+      s"""
+       ${eval.map(_.code).mkString("\n")}
+       ${consume(ctx, eval)}
+       """
+    }
+
+    // Build the map once (guarded by initAgg) and record its peak memory exactly once there.
+    val doAgg = ctx.freshName("doAggregateWithKeys")
+    ctx.addNewFunction(doAgg,
+      s"""
+        private void $doAgg() throws java.io.IOException {
+          ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+          $thisPlan.updatePeakMemory($hashMapTerm);
+          $iterTerm = $hashMapTerm.iterator();
+        }
+       """)
+
+    s"""
+     if (!$initAgg) {
+       $initAgg = true;
+       $doAgg();
+     }
+
+     // output the result
+     while ($iterTerm.next()) {
+       UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
+       UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
+       $outputCode
+     }
+
+     $hashMapTerm.free();
+     """
+  }
+
+  private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    // Per input row: build its grouping key, fetch that key's buffer, update it in place.
+    // create grouping key
+    ctx.currentVars = input
+    val keyCode = GenerateUnsafeProjection.createCode(
+      ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+    val key = keyCode.value
+    val buffer = ctx.freshName("aggBuffer")
+
+    // codegen only handles DeclarativeAggregate; ImperativeAggregate fails supportCodegen
+    val updateExpr = aggregateExpressions.flatMap { e =>
+      e.mode match {
+        case Partial | Complete =>
+          e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
+        case PartialMerge | Final =>
+          e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
+      }
+    }
+
+    val inputAttr = bufferAttributes ++ child.output
+    ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input
+    ctx.INPUT_ROW = buffer  // buffer columns (null current vars) are expected to read this row
+    // TODO: support subexpression elimination
+    val evals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx))
+    val updates = evals.zip(updateExpr).zipWithIndex.map { case ((ev, e), i) =>
+      // buffer slot i receives the new value of the i-th update/merge expression
+      ctx.updateColumn(buffer, e.dataType, i, ev, e.nullable)
+    }
+
+    s"""
+     // generate grouping key
+     ${keyCode.code}
+     UnsafeRow $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
+     if ($buffer == null) {
+       // failed to allocate the first page
+       throw new OutOfMemoryError("Not enough memory for aggregation");
+     }
+
+     // evaluate aggregate function
+     ${evals.map(_.code).mkString("\n")}
+     // update aggregate buffer
+     ${updates.mkString("\n")}
+     """
+  }
+
   override def simpleString: String = {
     val allAggregateExpressions = aggregateExpressions
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e6a02c66/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index c4aad39..2f09c8a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -18,7 +18,12 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
 import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.hash.Murmur3_x86_32
+import org.apache.spark.unsafe.map.BytesToBytesMap
 import org.apache.spark.util.Benchmark
 
 /**
@@ -27,34 +32,124 @@ import org.apache.spark.util.Benchmark
   *  build/sbt "sql/test-only *BenchmarkWholeStageCodegen"
   */
 class BenchmarkWholeStageCodegen extends SparkFunSuite {
-  def testWholeStage(values: Int): Unit = {
-    val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark")
-    val sc = SparkContext.getOrCreate(conf)
-    val sqlContext = SQLContext.getOrCreate(sc)
+  lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark")
+  lazy val sc = SparkContext.getOrCreate(conf)
+  lazy val sqlContext = SQLContext.getOrCreate(sc)

-    val benchmark = new Benchmark("Single Int Column Scan", values)
+  def testWholeStage(values: Int): Unit = {
+    val benchmark = new Benchmark("range/filter/aggregate", values)

-    benchmark.addCase("Without whole stage codegen") { iter =>
+    benchmark.addCase("Without codegen") { iter =>
       sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
       sqlContext.range(values).filter("(id & 1) = 1").count()
     }

-    benchmark.addCase("With whole stage codegen") { iter =>
+    benchmark.addCase("With codegen") { iter =>
       sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
       sqlContext.range(values).filter("(id & 1) = 1").count()
     }

     /*
       Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
-      Single Int Column Scan:            Avg Time(ms)    Avg Rate(M/s)  Relative Rate
+      range/filter/aggregate:            Avg Time(ms)    Avg Rate(M/s)  Relative Rate
       -------------------------------------------------------------------------------
-      Without whole stage codegen             7775.53            26.97         1.00 X
-      With whole stage codegen                 342.15           612.94        22.73 X
+      Without codegen                         7775.53            26.97         1.00 X
+      With codegen                             342.15           612.94        22.73 X
     */
     benchmark.run()
   }

-  ignore("benchmark") {
-    testWholeStage(1024 * 1024 * 200)
+  def testAggregateWithKey(values: Int): Unit = {
+    val benchmark = new Benchmark("Aggregate with keys", values)
+    // keys are (id & 65535), so there are at most 65536 groups
+    benchmark.addCase("Aggregate w/o codegen") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
+      sqlContext.range(values).selectExpr("(id & 65535) as k").groupBy("k").sum().collect()
+    }
+    benchmark.addCase("Aggregate w codegen") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
+      sqlContext.range(values).selectExpr("(id & 65535) as k").groupBy("k").sum().collect()
+    }
+
+    /*
+    Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+    Aggregate with keys:               Avg Time(ms)    Avg Rate(M/s)  Relative Rate
+    -------------------------------------------------------------------------------
+    Aggregate w/o codegen                   4254.38             4.93         1.00 X
+    Aggregate w codegen                     2661.45             7.88         1.60 X
+    */
+    benchmark.run()
+  }
+
+  def testBytesToBytesMap(values: Int): Unit = {
+    val benchmark = new Benchmark("BytesToBytesMap", values)
+    // baseline: only hash the key, no map lookup or insert
+    benchmark.addCase("hash") { iter =>
+      var i = 0
+      val keyBytes = new Array[Byte](16)
+      val valueBytes = new Array[Byte](16)
+      val key = new UnsafeRow(1)
+      key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+      val value = new UnsafeRow(2)
+      value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+      var s = 0
+      while (i < values) {
+        key.setInt(0, i % 1000)
+        val h = Murmur3_x86_32.hashUnsafeWords(
+          key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 0)
+        s += h
+        i += 1
+      }
+    }
+
+    Seq("off", "on").foreach { heap =>
+      benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter =>
+        val taskMemoryManager = new TaskMemoryManager(
+          new StaticMemoryManager(
+            new SparkConf().set("spark.memory.offHeap.enabled", s"${heap == "off"}")
+              .set("spark.memory.offHeap.size", "102400000"),
+            Long.MaxValue, Long.MaxValue, 1),
+          0)
+        val map = new BytesToBytesMap(taskMemoryManager, 1024, 64L<<20)
+        val keyBytes = new Array[Byte](16)
+        val valueBytes = new Array[Byte](16)
+        val key = new UnsafeRow(1)
+        key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+        val value = new UnsafeRow(2)
+        value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+        var i = 0
+        while (i < values) {
+          key.setInt(0, i % 65536)
+          val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
+          if (loc.isDefined) {
+            value.pointTo(loc.getValueAddress.getBaseObject, loc.getValueAddress.getBaseOffset,
+              loc.getValueLength)
+            value.setInt(0, value.getInt(0) + 1)
+            i += 1
+          } else {  // insert from the zeroed valueBytes: `value` may now point into the map
+            loc.putNewKey(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
+              valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+          }
+        }
+        // free the map's pages before the next iteration builds a new map
+        map.free()
+      }
+    }
+
+    /*
+    Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+    BytesToBytesMap:                   Avg Time(ms)    Avg Rate(M/s)  Relative Rate
+    -------------------------------------------------------------------------------
+    hash                                     662.06            79.19         1.00 X
+    BytesToBytesMap (off Heap)              2209.42            23.73         0.30 X
+    BytesToBytesMap (on Heap)               2957.68            17.73         0.22 X
+    */
+    benchmark.run()
+  }
+
+  ignore("benchmark") {
+    testWholeStage(1024 * 1024 * 200)
+    testAggregateWithKey(20 << 20)
+    testBytesToBytesMap(1024 * 1024 * 50)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e6a02c66/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 300788c..c251650 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -47,4 +47,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
         p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
     assert(df.collect() === Array(Row(9, 4.5)))
   }
+
+  test("Aggregate with grouping keys should be included in WholeStageCodegen") {
+    val df = sqlContext.range(3).groupBy("id").count().orderBy("id")
+    val plan = df.queryExecution.executedPlan  // physical plan, with WholeStageCodegen nodes
+    assert(plan.find(p =>
+      p.isInstanceOf[WholeStageCodegen] &&
+        p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
+    assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))  // each id is a group
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org