Posted to commits@spark.apache.org by yh...@apache.org on 2016/04/15 05:57:07 UTC

spark git commit: [SPARK-14447][SQL] Speed up TungstenAggregate w/ keys using VectorizedHashMap

Repository: spark
Updated Branches:
  refs/heads/master ff9ae61a3 -> b5c60bcdc


[SPARK-14447][SQL] Speed up TungstenAggregate w/ keys using VectorizedHashMap

## What changes were proposed in this pull request?

This patch speeds up group-by aggregates by around 3-5x by leveraging an in-memory `AggregateHashMap` (please see https://github.com/apache/spark/pull/12161), an append-only aggregate hash map that acts as a 'cache' for extremely fast key-value lookups while evaluating aggregates, falling back to the `BytesToBytesMap` if a given key isn't found.

Architecturally, it is backed by a power-of-2-sized array for index lookups and a columnar batch that stores the key-value pairs. The index lookups in the array rely on linear probing (with a small maximum number of tries) and use an inexpensive hash function, which makes the map very efficient for the majority of lookups. However, linear probing and an inexpensive hash function also make it less robust than the `BytesToBytesMap` (especially for a large number of keys or for certain key distributions), so we fall back to the latter for correctness.
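
To make the lookup scheme concrete, below is a minimal Scala sketch of the probing logic for a single long group-by key and a single long sum buffer. It is illustrative only: the class and member names are invented for this example, plain arrays stand in for the generated `ColumnarBatch`-backed columns, and it does not guard against exceeding the row capacity.

```scala
// Illustrative sketch only: a simplified stand-in for the generated vectorized map.
class SimpleLongAggMap(capacity: Int = 1 << 16, loadFactor: Double = 0.25, maxSteps: Int = 5) {
  require(capacity > 0 && (capacity & (capacity - 1)) == 0, "capacity must be a power of 2")
  private val numBuckets = (capacity / loadFactor).toInt  // also a power of 2
  private val keys = new Array[Long](capacity)            // key column
  private val sums = new Array[Long](capacity)            // aggregate buffer column
  private val buckets = Array.fill(numBuckets)(-1)        // -1 marks an empty slot
  private var numRows = 0

  // Inexpensive hash: for a single long key the xor of all group-by keys is the key itself
  private def hash(key: Long): Long = key

  /** Returns the row index holding the key's aggregate buffer, or -1 when probing fails
    * within `maxSteps`, signalling a fallback to the `BytesToBytesMap`. */
  def findOrInsert(key: Long): Int = {
    var idx = (hash(key) & (numBuckets - 1)).toInt
    var step = 0
    while (step < maxSteps) {
      if (buckets(idx) == -1) {                // empty slot: insert the key
        keys(numRows) = key
        sums(numRows) = 0L
        buckets(idx) = numRows
        numRows += 1
        return buckets(idx)
      } else if (keys(buckets(idx)) == key) {  // key already present
        return buckets(idx)
      }
      idx = (idx + 1) & (numBuckets - 1)       // linear probing
      step += 1
    }
    -1                                         // didn't find it within maxSteps: fall back
  }

  def update(row: Int, value: Long): Unit = sums(row) += value
}
```

A probe that exhausts `maxSteps` returns -1 here, mirroring the generated `findOrInsert` returning null so that TungstenAggregate can route the row to the `BytesToBytesMap` instead.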

## How was this patch tested?

    Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4
    Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
    Aggregate w keys:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
    -------------------------------------------------------------------------------------------
    codegen = F                              2124 / 2204          9.9         101.3       1.0X
    codegen = T hashmap = F                  1198 / 1364         17.5          57.1       1.8X
    codegen = T hashmap = T                   369 /  600         56.8          17.6       5.8X
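
For reference, the `hashmap = T` rows above exercise the new internal flag added in this patch. A minimal way to reproduce the setup (assuming an existing `sqlContext`, as in `BenchmarkWholeStageCodegen`) is:

```scala
// Enable whole-stage codegen and the new columnar aggregate map
sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
sqlContext.setConf("spark.sql.codegen.aggregate.map.enabled", "true")

// Group-by aggregate over ~20M rows with 65536 distinct keys, as in the benchmark
sqlContext.range(20 << 20).selectExpr("(id & 65535) as k").groupBy("k").sum().collect()
```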

Author: Sameer Agarwal <sa...@databricks.com>

Closes #12345 from sameeragarwal/tungsten-aggregate-integration.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b5c60bcd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b5c60bcd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b5c60bcd

Branch: refs/heads/master
Commit: b5c60bcdca3bcace607b204a6c196a5386e8a896
Parents: ff9ae61
Author: Sameer Agarwal <sa...@databricks.com>
Authored: Thu Apr 14 20:57:03 2016 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Thu Apr 14 20:57:03 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/execution/WholeStageCodegen.scala |   1 +
 .../aggregate/ColumnarAggMapCodeGenerator.scala | 193 ---------------
 .../execution/aggregate/TungstenAggregate.scala | 227 ++++++++++++-----
 .../aggregate/TungstenAggregationIterator.scala |   8 +-
 .../aggregate/VectorizedHashMapGenerator.scala  | 241 +++++++++++++++++++
 .../org/apache/spark/sql/internal/SQLConf.scala |   9 +
 .../execution/BenchmarkWholeStageCodegen.scala  |  34 ++-
 .../hive/execution/AggregationQuerySuite.scala  |  47 ++--
 8 files changed, 479 insertions(+), 281 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b5c60bcd/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 447dbe7..29acc38 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -126,6 +126,7 @@ trait CodegenSupport extends SparkPlan {
         // outputVars will be used to generate the code for UnsafeRow, so we should copy them
         outputVars.map(_.copy())
       }
+
     val rowVar = if (row != null) {
       ExprCode("", "false", row)
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/b5c60bcd/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala
deleted file mode 100644
index e415dd8..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala
+++ /dev/null
@@ -1,193 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.aggregate
-
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
-import org.apache.spark.sql.types.StructType
-
-/**
- * This is a helper object to generate an append-only single-key/single value aggregate hash
- * map that can act as a 'cache' for extremely fast key-value lookups while evaluating aggregates
- * (and fall back to the `BytesToBytesMap` if a given key isn't found). This is 'codegened' in
- * TungstenAggregate to speed up aggregates w/ key.
- *
- * It is backed by a power-of-2-sized array for index lookups and a columnar batch that stores the
- * key-value pairs. The index lookups in the array rely on linear probing (with a small number of
- * maximum tries) and use an inexpensive hash function which makes it really efficient for a
- * majority of lookups. However, using linear probing and an inexpensive hash function also makes it
- * less robust as compared to the `BytesToBytesMap` (especially for a large number of keys or even
- * for certain distribution of keys) and requires us to fall back on the latter for correctness.
- */
-class ColumnarAggMapCodeGenerator(
-    ctx: CodegenContext,
-    generatedClassName: String,
-    groupingKeySchema: StructType,
-    bufferSchema: StructType) {
-  val groupingKeys = groupingKeySchema.map(k => (k.dataType.typeName, ctx.freshName("key")))
-  val bufferValues = bufferSchema.map(k => (k.dataType.typeName, ctx.freshName("value")))
-  val groupingKeySignature = groupingKeys.map(_.productIterator.toList.mkString(" ")).mkString(", ")
-
-  def generate(): String = {
-    s"""
-       |public class $generatedClassName {
-       |${initializeAggregateHashMap()}
-       |
-       |${generateFindOrInsert()}
-       |
-       |${generateEquals()}
-       |
-       |${generateHashFunction()}
-       |}
-     """.stripMargin
-  }
-
-  private def initializeAggregateHashMap(): String = {
-    val generatedSchema: String =
-      s"""
-         |new org.apache.spark.sql.types.StructType()
-         |${(groupingKeySchema ++ bufferSchema).map(key =>
-          s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})""")
-          .mkString("\n")};
-      """.stripMargin
-
-    s"""
-       |  private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch;
-       |  private int[] buckets;
-       |  private int numBuckets;
-       |  private int maxSteps;
-       |  private int numRows = 0;
-       |  private org.apache.spark.sql.types.StructType schema = $generatedSchema
-       |
-       |  public $generatedClassName(int capacity, double loadFactor, int maxSteps) {
-       |    assert (capacity > 0 && ((capacity & (capacity - 1)) == 0));
-       |    this.maxSteps = maxSteps;
-       |    numBuckets = (int) (capacity / loadFactor);
-       |    batch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate(schema,
-       |      org.apache.spark.memory.MemoryMode.ON_HEAP, capacity);
-       |    buckets = new int[numBuckets];
-       |    java.util.Arrays.fill(buckets, -1);
-       |  }
-       |
-       |  public $generatedClassName() {
-       |    new $generatedClassName(1 << 16, 0.25, 5);
-       |  }
-     """.stripMargin
-  }
-
-  /**
-   * Generates a method that computes a hash by currently xor-ing all individual group-by keys. For
-   * instance, if we have 2 long group-by keys, the generated function would be of the form:
-   *
-   * {{{
-   * private long hash(long agg_key, long agg_key1) {
-   *   return agg_key ^ agg_key1;
-   *   }
-   * }}}
-   */
-  private def generateHashFunction(): String = {
-    s"""
-       |// TODO: Improve this hash function
-       |private long hash($groupingKeySignature) {
-       |  return ${groupingKeys.map(_._2).mkString(" ^ ")};
-       |}
-     """.stripMargin
-  }
-
-  /**
-   * Generates a method that returns true if the group-by keys exist at a given index in the
-   * associated [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
-   * have 2 long group-by keys, the generated function would be of the form:
-   *
-   * {{{
-   * private boolean equals(int idx, long agg_key, long agg_key1) {
-   *   return batch.column(0).getLong(buckets[idx]) == agg_key &&
-   *     batch.column(1).getLong(buckets[idx]) == agg_key1;
-   * }
-   * }}}
-   */
-  private def generateEquals(): String = {
-    s"""
-       |private boolean equals(int idx, $groupingKeySignature) {
-       |  return ${groupingKeys.zipWithIndex.map(k =>
-            s"batch.column(${k._2}).getLong(buckets[idx]) == ${k._1._2}").mkString(" && ")};
-       |}
-     """.stripMargin
-  }
-
-  /**
-   * Generates a method that returns a mutable
-   * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row]] which keeps track of the
-   * aggregate value(s) for a given set of keys. If the corresponding row doesn't exist, the
-   * generated method adds the corresponding row in the associated
-   * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
-   * have 2 long group-by keys, the generated function would be of the form:
-   *
-   * {{{
-   * public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(
-   *     long agg_key, long agg_key1) {
-   *   long h = hash(agg_key, agg_key1);
-   *   int step = 0;
-   *   int idx = (int) h & (numBuckets - 1);
-   *   while (step < maxSteps) {
-   *     // Return bucket index if it's either an empty slot or already contains the key
-   *     if (buckets[idx] == -1) {
-   *       batch.column(0).putLong(numRows, agg_key);
-   *       batch.column(1).putLong(numRows, agg_key1);
-   *       batch.column(2).putLong(numRows, 0);
-   *       buckets[idx] = numRows++;
-   *       return batch.getRow(buckets[idx]);
-   *     } else if (equals(idx, agg_key, agg_key1)) {
-   *       return batch.getRow(buckets[idx]);
-   *     }
-   *     idx = (idx + 1) & (numBuckets - 1);
-   *     step++;
-   *   }
-   *   // Didn't find it
-   *   return null;
-   * }
-   * }}}
-   */
-  private def generateFindOrInsert(): String = {
-    s"""
-       |public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(${
-            groupingKeySignature}) {
-       |  long h = hash(${groupingKeys.map(_._2).mkString(", ")});
-       |  int step = 0;
-       |  int idx = (int) h & (numBuckets - 1);
-       |  while (step < maxSteps) {
-       |    // Return bucket index if it's either an empty slot or already contains the key
-       |    if (buckets[idx] == -1) {
-       |      ${groupingKeys.zipWithIndex.map(k =>
-                s"batch.column(${k._2}).putLong(numRows, ${k._1._2});").mkString("\n")}
-       |      ${bufferValues.zipWithIndex.map(k =>
-                s"batch.column(${groupingKeys.length + k._2}).putLong(numRows, 0);")
-                .mkString("\n")}
-       |      buckets[idx] = numRows++;
-       |      return batch.getRow(buckets[idx]);
-       |    } else if (equals(idx, ${groupingKeys.map(_._2).mkString(", ")})) {
-       |      return batch.getRow(buckets[idx]);
-       |    }
-       |    idx = (idx + 1) & (numBuckets - 1);
-       |    step++;
-       |  }
-       |  // Didn't find it
-       |  return null;
-       |}
-     """.stripMargin
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/b5c60bcd/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 2535920..f585759 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -70,12 +70,14 @@ case class TungstenAggregate(
     }
   }
 
-  // This is for testing. We force TungstenAggregationIterator to fall back to sort-based
-  // aggregation once it has processed a given number of input rows.
-  private val testFallbackStartsAt: Option[Int] = {
+  // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash
+  // map and/or the sort-based aggregation once it has processed a given number of input rows.
+  private val testFallbackStartsAt: Option[(Int, Int)] = {
     sqlContext.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null) match {
       case null | "" => None
-      case fallbackStartsAt => Some(fallbackStartsAt.toInt)
+      case fallbackStartsAt =>
+        val splits = fallbackStartsAt.split(",").map(_.trim)
+        Some((splits.head.toInt, splits.last.toInt))
     }
   }
 
@@ -261,7 +263,15 @@ case class TungstenAggregate(
     .map(_.asInstanceOf[DeclarativeAggregate])
   private val bufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
 
-  // The name for HashMap
+  // The name for Vectorized HashMap
+  private var vectorizedHashMapTerm: String = _
+
+  // We currently only enable vectorized hashmap for long key/value types and partial aggregates
+  private val isVectorizedHashMapEnabled: Boolean = sqlContext.conf.columnarAggregateMapEnabled &&
+    (groupingKeySchema ++ bufferSchema).forall(_.dataType == LongType) &&
+    modes.forall(mode => mode == Partial || mode == PartialMerge)
+
+  // The name for UnsafeRow HashMap
   private var hashMapTerm: String = _
   private var sorterTerm: String = _
 
@@ -437,17 +447,18 @@ case class TungstenAggregate(
     val initAgg = ctx.freshName("initAgg")
     ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
 
-    // create AggregateHashMap
-    val isAggregateHashMapEnabled: Boolean = false
-    val isAggregateHashMapSupported: Boolean =
-      (groupingKeySchema ++ bufferSchema).forall(_.dataType == LongType)
-    val aggregateHashMapTerm = ctx.freshName("aggregateHashMap")
-    val aggregateHashMapClassName = ctx.freshName("GeneratedAggregateHashMap")
-    val aggregateHashMapGenerator = new ColumnarAggMapCodeGenerator(ctx, aggregateHashMapClassName,
+    vectorizedHashMapTerm = ctx.freshName("vectorizedHashMap")
+    val vectorizedHashMapClassName = ctx.freshName("VectorizedHashMap")
+    val vectorizedHashMapGenerator = new VectorizedHashMapGenerator(ctx, vectorizedHashMapClassName,
       groupingKeySchema, bufferSchema)
-    if (isAggregateHashMapEnabled && isAggregateHashMapSupported) {
-      ctx.addMutableState(aggregateHashMapClassName, aggregateHashMapTerm,
-        s"$aggregateHashMapTerm = new $aggregateHashMapClassName();")
+    // Create a name for iterator from vectorized HashMap
+    val iterTermForVectorizedHashMap = ctx.freshName("vectorizedHashMapIter")
+    if (isVectorizedHashMapEnabled) {
+      ctx.addMutableState(vectorizedHashMapClassName, vectorizedHashMapTerm,
+        s"$vectorizedHashMapTerm = new $vectorizedHashMapClassName();")
+      ctx.addMutableState(
+        "java.util.Iterator<org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row>",
+        iterTermForVectorizedHashMap, "")
     }
 
     // create hashMap
@@ -465,11 +476,14 @@ case class TungstenAggregate(
     val doAgg = ctx.freshName("doAggregateWithKeys")
     ctx.addNewFunction(doAgg,
       s"""
-        ${if (isAggregateHashMapSupported) aggregateHashMapGenerator.generate() else ""}
+        ${if (isVectorizedHashMapEnabled) vectorizedHashMapGenerator.generate() else ""}
         private void $doAgg() throws java.io.IOException {
           $hashMapTerm = $thisPlan.createHashMap();
           ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
 
+          ${if (isVectorizedHashMapEnabled) {
+              s"$iterTermForVectorizedHashMap = $vectorizedHashMapTerm.rowIterator();"} else ""}
+
           $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm);
         }
        """)
@@ -484,6 +498,34 @@ case class TungstenAggregate(
     // so `copyResult` should be reset to `false`.
     ctx.copyResult = false
 
+    // Iterate over the aggregate rows and convert them from ColumnarBatch.Row to UnsafeRow
+    def outputFromGeneratedMap: Option[String] = {
+      if (isVectorizedHashMapEnabled) {
+        val row = ctx.freshName("vectorizedHashMapRow")
+        ctx.currentVars = null
+        ctx.INPUT_ROW = row
+        var schema: StructType = groupingKeySchema
+        bufferSchema.foreach(i => schema = schema.add(i))
+        val generateRow = GenerateUnsafeProjection.createCode(ctx, schema.toAttributes.zipWithIndex
+          .map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) })
+        Option(
+          s"""
+             | while ($iterTermForVectorizedHashMap.hasNext()) {
+             |   $numOutput.add(1);
+             |   org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $row =
+             |     (org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row)
+             |     $iterTermForVectorizedHashMap.next();
+             |   ${generateRow.code}
+             |   ${consume(ctx, Seq.empty, {generateRow.value})}
+             |
+             |   if (shouldStop()) return;
+             | }
+             |
+             | $vectorizedHashMapTerm.close();
+           """.stripMargin)
+      } else None
+    }
+
     s"""
      if (!$initAgg) {
        $initAgg = true;
@@ -491,6 +533,8 @@ case class TungstenAggregate(
      }
 
      // output the result
+     ${outputFromGeneratedMap.getOrElse("")}
+
      while ($iterTerm.next()) {
        $numOutput.add(1);
        UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
@@ -511,10 +555,13 @@ case class TungstenAggregate(
 
     // create grouping key
     ctx.currentVars = input
-    val keyCode = GenerateUnsafeProjection.createCode(
+    val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
       ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
-    val key = keyCode.value
-    val buffer = ctx.freshName("aggBuffer")
+    val vectorizedRowKeys = ctx.generateExpressions(
+      groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+    val unsafeRowKeys = unsafeRowKeyCode.value
+    val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")
+    val vectorizedRowBuffer = ctx.freshName("vectorizedAggBuffer")
 
     // only have DeclarativeAggregate
     val updateExpr = aggregateExpressions.flatMap { e =>
@@ -533,56 +580,124 @@ case class TungstenAggregate(
 
     val inputAttr = aggregateBufferAttributes ++ child.output
     ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input
-    ctx.INPUT_ROW = buffer
-    // TODO: support subexpression elimination
-    val evals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx))
-    val updates = evals.zipWithIndex.map { case (ev, i) =>
-      val dt = updateExpr(i).dataType
-      ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable)
-    }
 
-    val (checkFallback, resetCoulter, incCounter) = if (testFallbackStartsAt.isDefined) {
+    val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter,
+    incCounter) = if (testFallbackStartsAt.isDefined) {
       val countTerm = ctx.freshName("fallbackCounter")
       ctx.addMutableState("int", countTerm, s"$countTerm = 0;")
-      (s"$countTerm < ${testFallbackStartsAt.get}", s"$countTerm = 0;", s"$countTerm += 1;")
+      (s"$countTerm < ${testFallbackStartsAt.get._1}",
+        s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;")
     } else {
-      ("true", "", "")
+      ("true", "true", "", "")
     }
 
+    // We first generate code to probe and update the vectorized hash map. If the probe is
+    // successful the corresponding vectorized row buffer will hold the mutable row
+    val findOrInsertInVectorizedHashMap: Option[String] = {
+      if (isVectorizedHashMapEnabled) {
+        Option(
+          s"""
+             |if ($checkFallbackForGeneratedHashMap) {
+             |  ${vectorizedRowKeys.map(_.code).mkString("\n")}
+             |  if (${vectorizedRowKeys.map("!" + _.isNull).mkString(" && ")}) {
+             |    $vectorizedRowBuffer = $vectorizedHashMapTerm.findOrInsert(
+             |        ${vectorizedRowKeys.map(_.value).mkString(", ")});
+             |  }
+             |}
+         """.stripMargin)
+      } else {
+        None
+      }
+    }
+
+    val updateRowInVectorizedHashMap: Option[String] = {
+      if (isVectorizedHashMapEnabled) {
+        ctx.INPUT_ROW = vectorizedRowBuffer
+        val vectorizedRowEvals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx))
+        val updateVectorizedRow = vectorizedRowEvals.zipWithIndex.map { case (ev, i) =>
+          val dt = updateExpr(i).dataType
+          ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable)
+        }
+        Option(
+          s"""
+             |// evaluate aggregate function
+             |${evaluateVariables(vectorizedRowEvals)}
+             |// update vectorized row
+             |${updateVectorizedRow.mkString("\n").trim}
+           """.stripMargin)
+      } else None
+    }
+
+    // Next, we generate code to probe and update the unsafe row hash map.
+    val findOrInsertInUnsafeRowMap: String = {
+      s"""
+         | if ($vectorizedRowBuffer == null) {
+         |   // generate grouping key
+         |   ${unsafeRowKeyCode.code.trim}
+         |   ${hashEval.code.trim}
+         |   if ($checkFallbackForBytesToBytesMap) {
+         |     // try to get the buffer from hash map
+         |     $unsafeRowBuffer =
+         |       $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, ${hashEval.value});
+         |   }
+         |   if ($unsafeRowBuffer == null) {
+         |     if ($sorterTerm == null) {
+         |       $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
+         |     } else {
+         |       $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
+         |     }
+         |     $resetCounter
+         |     // the hash map had be spilled, it should have enough memory now,
+         |     // try  to allocate buffer again.
+         |     $unsafeRowBuffer =
+         |       $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, ${hashEval.value});
+         |     if ($unsafeRowBuffer == null) {
+         |       // failed to allocate the first page
+         |       throw new OutOfMemoryError("No enough memory for aggregation");
+         |     }
+         |   }
+         | }
+       """.stripMargin
+    }
+
+    val updateRowInUnsafeRowMap: String = {
+      ctx.INPUT_ROW = unsafeRowBuffer
+      val unsafeRowBufferEvals =
+        updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx))
+      val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>
+        val dt = updateExpr(i).dataType
+        ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
+      }
+      s"""
+         |// evaluate aggregate function
+         |${evaluateVariables(unsafeRowBufferEvals)}
+         |// update unsafe row buffer
+         |${updateUnsafeRowBuffer.mkString("\n").trim}
+           """.stripMargin
+    }
+
+
     // We try to do hash map based in-memory aggregation first. If there is not enough memory (the
     // hash map will return null for new key), we spill the hash map to disk to free memory, then
     // continue to do in-memory aggregation and spilling until all the rows had been processed.
     // Finally, sort the spilled aggregate buffers by key, and merge them together for same key.
     s"""
-     // generate grouping key
-     ${keyCode.code.trim}
-     ${hashEval.code.trim}
-     UnsafeRow $buffer = null;
-     if ($checkFallback) {
-       // try to get the buffer from hash map
-       $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value});
-     }
-     if ($buffer == null) {
-       if ($sorterTerm == null) {
-         $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
-       } else {
-         $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
-       }
-       $resetCoulter
-       // the hash map had be spilled, it should have enough memory now,
-       // try  to allocate buffer again.
-       $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value});
-       if ($buffer == null) {
-         // failed to allocate the first page
-         throw new OutOfMemoryError("No enough memory for aggregation");
-       }
-     }
+     UnsafeRow $unsafeRowBuffer = null;
+     org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $vectorizedRowBuffer = null;
+
+     ${findOrInsertInVectorizedHashMap.getOrElse("")}
+
+     $findOrInsertInUnsafeRowMap
+
      $incCounter
 
-     // evaluate aggregate function
-     ${evaluateVariables(evals)}
-     // update aggregate buffer
-     ${updates.mkString("\n").trim}
+     if ($vectorizedRowBuffer != null) {
+       // update vectorized row
+       ${updateRowInVectorizedHashMap.getOrElse("")}
+     } else {
+       // update unsafe row
+       $updateRowInUnsafeRowMap
+     }
      """
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b5c60bcd/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index ce504e2..09384a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -85,7 +85,7 @@ class TungstenAggregationIterator(
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
     originalInputAttributes: Seq[Attribute],
     inputIter: Iterator[InternalRow],
-    testFallbackStartsAt: Option[Int],
+    testFallbackStartsAt: Option[(Int, Int)],
     numOutputRows: LongSQLMetric,
     dataSize: LongSQLMetric,
     spillSize: LongSQLMetric)
@@ -171,7 +171,7 @@ class TungstenAggregationIterator(
   // hashMap. If there is not enough memory, it will multiple hash-maps, spilling
   // after each becomes full then using sort to merge these spills, finally do sort
   // based aggregation.
-  private def processInputs(fallbackStartsAt: Int): Unit = {
+  private def processInputs(fallbackStartsAt: (Int, Int)): Unit = {
     if (groupingExpressions.isEmpty) {
       // If there is no grouping expressions, we can just reuse the same buffer over and over again.
       // Note that it would be better to eliminate the hash map entirely in the future.
@@ -187,7 +187,7 @@ class TungstenAggregationIterator(
         val newInput = inputIter.next()
         val groupingKey = groupingProjection.apply(newInput)
         var buffer: UnsafeRow = null
-        if (i < fallbackStartsAt) {
+        if (i < fallbackStartsAt._2) {
           buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
         }
         if (buffer == null) {
@@ -352,7 +352,7 @@ class TungstenAggregationIterator(
   /**
    * Start processing input rows.
    */
-  processInputs(testFallbackStartsAt.getOrElse(Int.MaxValue))
+  processInputs(testFallbackStartsAt.getOrElse((Int.MaxValue, Int.MaxValue)))
 
   // If we did not switch to sort-based aggregation in processInputs,
   // we pre-load the first key-value pair from the map (to make hasNext idempotent).

http://git-wip-us.apache.org/repos/asf/spark/blob/b5c60bcd/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
new file mode 100644
index 0000000..395cc7a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.types.StructType
+
+/**
+ * This is a helper class to generate an append-only vectorized hash map that can act as a 'cache'
+ * for extremely fast key-value lookups while evaluating aggregates (and fall back to the
+ * `BytesToBytesMap` if a given key isn't found). This is 'codegened' in TungstenAggregate to speed
+ * up aggregates w/ key.
+ *
+ * It is backed by a power-of-2-sized array for index lookups and a columnar batch that stores the
+ * key-value pairs. The index lookups in the array rely on linear probing (with a small number of
+ * maximum tries) and use an inexpensive hash function which makes it really efficient for a
+ * majority of lookups. However, using linear probing and an inexpensive hash function also makes it
+ * less robust as compared to the `BytesToBytesMap` (especially for a large number of keys or even
+ * for certain distribution of keys) and requires us to fall back on the latter for correctness. We
+ * also use a secondary columnar batch that logically projects over the original columnar batch and
+ * is equivalent to the `BytesToBytesMap` aggregate buffer.
+ *
+ * NOTE: This vectorized hash map currently doesn't support nullable keys and falls back to the
+ * `BytesToBytesMap` to store them.
+ */
+class VectorizedHashMapGenerator(
+    ctx: CodegenContext,
+    generatedClassName: String,
+    groupingKeySchema: StructType,
+    bufferSchema: StructType) {
+  val groupingKeys = groupingKeySchema.map(k => (k.dataType.typeName, ctx.freshName("key")))
+  val bufferValues = bufferSchema.map(k => (k.dataType.typeName, ctx.freshName("value")))
+  val groupingKeySignature = groupingKeys.map(_.productIterator.toList.mkString(" ")).mkString(", ")
+
+  def generate(): String = {
+    s"""
+       |public class $generatedClassName {
+       |${initializeAggregateHashMap()}
+       |
+       |${generateFindOrInsert()}
+       |
+       |${generateEquals()}
+       |
+       |${generateHashFunction()}
+       |
+       |${generateRowIterator()}
+       |
+       |${generateClose()}
+       |}
+     """.stripMargin
+  }
+
+  private def initializeAggregateHashMap(): String = {
+    val generatedSchema: String =
+      s"""
+         |new org.apache.spark.sql.types.StructType()
+         |${(groupingKeySchema ++ bufferSchema).map(key =>
+          s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})""")
+          .mkString("\n")};
+      """.stripMargin
+
+    val generatedAggBufferSchema: String =
+      s"""
+         |new org.apache.spark.sql.types.StructType()
+         |${bufferSchema.map(key =>
+        s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})""")
+        .mkString("\n")};
+      """.stripMargin
+
+    s"""
+       |  private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch;
+       |  private org.apache.spark.sql.execution.vectorized.ColumnarBatch aggregateBufferBatch;
+       |  private int[] buckets;
+       |  private int numBuckets;
+       |  private int maxSteps;
+       |  private int numRows = 0;
+       |  private org.apache.spark.sql.types.StructType schema = $generatedSchema
+       |  private org.apache.spark.sql.types.StructType aggregateBufferSchema =
+       |    $generatedAggBufferSchema
+       |
+       |  public $generatedClassName() {
+       |    // TODO: These should be generated based on the schema
+       |    int DEFAULT_CAPACITY = 1 << 16;
+       |    double DEFAULT_LOAD_FACTOR = 0.25;
+       |    int DEFAULT_MAX_STEPS = 2;
+       |    assert (DEFAULT_CAPACITY > 0 && ((DEFAULT_CAPACITY & (DEFAULT_CAPACITY - 1)) == 0));
+       |    this.maxSteps = DEFAULT_MAX_STEPS;
+       |    numBuckets = (int) (DEFAULT_CAPACITY / DEFAULT_LOAD_FACTOR);
+       |
+       |    batch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate(schema,
+       |      org.apache.spark.memory.MemoryMode.ON_HEAP, DEFAULT_CAPACITY);
+       |
+       |    // TODO: Possibly generate this projection in TungstenAggregate directly
+       |    aggregateBufferBatch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate(
+       |      aggregateBufferSchema, org.apache.spark.memory.MemoryMode.ON_HEAP, DEFAULT_CAPACITY);
+       |    for (int i = 0 ; i < aggregateBufferBatch.numCols(); i++) {
+       |       aggregateBufferBatch.setColumn(i, batch.column(i+${groupingKeys.length}));
+       |    }
+       |
+       |    buckets = new int[numBuckets];
+       |    java.util.Arrays.fill(buckets, -1);
+       |  }
+     """.stripMargin
+  }
+
+  /**
+   * Generates a method that computes a hash by currently xor-ing all individual group-by keys. For
+   * instance, if we have 2 long group-by keys, the generated function would be of the form:
+   *
+   * {{{
+   * private long hash(long agg_key, long agg_key1) {
+   *   return agg_key ^ agg_key1;
+   *   }
+   * }}}
+   */
+  private def generateHashFunction(): String = {
+    s"""
+       |// TODO: Improve this hash function
+       |private long hash($groupingKeySignature) {
+       |  return ${groupingKeys.map(_._2).mkString(" | ")};
+       |}
+     """.stripMargin
+  }
+
+  /**
+   * Generates a method that returns true if the group-by keys exist at a given index in the
+   * associated [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
+   * have 2 long group-by keys, the generated function would be of the form:
+   *
+   * {{{
+   * private boolean equals(int idx, long agg_key, long agg_key1) {
+   *   return batch.column(0).getLong(buckets[idx]) == agg_key &&
+   *     batch.column(1).getLong(buckets[idx]) == agg_key1;
+   * }
+   * }}}
+   */
+  private def generateEquals(): String = {
+    s"""
+       |private boolean equals(int idx, $groupingKeySignature) {
+       |  return ${groupingKeys.zipWithIndex.map(k =>
+            s"batch.column(${k._2}).getLong(buckets[idx]) == ${k._1._2}").mkString(" && ")};
+       |}
+     """.stripMargin
+  }
+
+  /**
+   * Generates a method that returns a mutable
+   * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row]] which keeps track of the
+   * aggregate value(s) for a given set of keys. If the corresponding row doesn't exist, the
+   * generated method adds the corresponding row in the associated
+   * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
+   * have 2 long group-by keys, the generated function would be of the form:
+   *
+   * {{{
+   * public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(
+   *     long agg_key, long agg_key1) {
+   *   long h = hash(agg_key, agg_key1);
+   *   int step = 0;
+   *   int idx = (int) h & (numBuckets - 1);
+   *   while (step < maxSteps) {
+   *     // Return bucket index if it's either an empty slot or already contains the key
+   *     if (buckets[idx] == -1) {
+   *       batch.column(0).putLong(numRows, agg_key);
+   *       batch.column(1).putLong(numRows, agg_key1);
+   *       batch.column(2).putLong(numRows, 0);
+   *       buckets[idx] = numRows++;
+   *       return batch.getRow(buckets[idx]);
+   *     } else if (equals(idx, agg_key, agg_key1)) {
+   *       return batch.getRow(buckets[idx]);
+   *     }
+   *     idx = (idx + 1) & (numBuckets - 1);
+   *     step++;
+   *   }
+   *   // Didn't find it
+   *   return null;
+   * }
+   * }}}
+   */
+  private def generateFindOrInsert(): String = {
+    s"""
+       |public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(${
+            groupingKeySignature}) {
+       |  long h = hash(${groupingKeys.map(_._2).mkString(", ")});
+       |  int step = 0;
+       |  int idx = (int) h & (numBuckets - 1);
+       |  while (step < maxSteps) {
+       |    // Return bucket index if it's either an empty slot or already contains the key
+       |    if (buckets[idx] == -1) {
+       |      ${groupingKeys.zipWithIndex.map(k =>
+                s"batch.column(${k._2}).putLong(numRows, ${k._1._2});").mkString("\n")}
+       |      ${bufferValues.zipWithIndex.map(k =>
+                s"batch.column(${groupingKeys.length + k._2}).putNull(numRows);")
+                .mkString("\n")}
+       |      buckets[idx] = numRows++;
+       |      batch.setNumRows(numRows);
+       |      aggregateBufferBatch.setNumRows(numRows);
+       |      return aggregateBufferBatch.getRow(buckets[idx]);
+       |    } else if (equals(idx, ${groupingKeys.map(_._2).mkString(", ")})) {
+       |      return aggregateBufferBatch.getRow(buckets[idx]);
+       |    }
+       |    idx = (idx + 1) & (numBuckets - 1);
+       |    step++;
+       |  }
+       |  // Didn't find it
+       |  return null;
+       |}
+     """.stripMargin
+  }
+
+  private def generateRowIterator(): String = {
+    s"""
+       |public java.util.Iterator<org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row>
+       |    rowIterator() {
+       |  return batch.rowIterator();
+       |}
+     """.stripMargin
+  }
+
+  private def generateClose(): String = {
+    s"""
+       |public void close() {
+       |  batch.close();
+       |}
+     """.stripMargin
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/b5c60bcd/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 2f9d63c..20d9a28 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -436,6 +436,13 @@ object SQLConf {
     .stringConf
     .createOptional
 
+  // TODO: This is still WIP and shouldn't be turned on without extensive test coverage
+  val COLUMNAR_AGGREGATE_MAP_ENABLED = SQLConfigBuilder("spark.sql.codegen.aggregate.map.enabled")
+    .internal()
+    .doc("When true, aggregate with keys use an in-memory columnar map to speed up execution.")
+    .booleanConf
+    .createWithDefault(false)
+
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
     val EXTERNAL_SORT = "spark.sql.planner.externalSort"
@@ -560,6 +567,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
 
   def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES)
 
+  def columnarAggregateMapEnabled: Boolean = getConf(COLUMNAR_AGGREGATE_MAP_ENABLED)
+
   override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL)
 
   override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL)

http://git-wip-us.apache.org/repos/asf/spark/blob/b5c60bcd/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 352fd07..d23f19c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -153,16 +153,36 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
   ignore("aggregate with keys") {
     val N = 20 << 20
 
-    runBenchmark("Aggregate w keys", N) {
-      sqlContext.range(N).selectExpr("(id & 65535) as k").groupBy("k").sum().collect()
+    val benchmark = new Benchmark("Aggregate w keys", N)
+    def f(): Unit = sqlContext.range(N).selectExpr("(id & 65535) as k").groupBy("k").sum().collect()
+
+    benchmark.addCase(s"codegen = F") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
+      f()
+    }
+
+    benchmark.addCase(s"codegen = T hashmap = F") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
+      sqlContext.setConf("spark.sql.codegen.aggregate.map.enabled", "false")
+      f()
     }
 
+    benchmark.addCase(s"codegen = T hashmap = T") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
+      sqlContext.setConf("spark.sql.codegen.aggregate.map.enabled", "true")
+      f()
+    }
+
+    benchmark.run()
+
     /*
-      Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
-      Aggregate w keys:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
-      -------------------------------------------------------------------------------------------
-      Aggregate w keys codegen=false           2429 / 2644          8.6         115.8       1.0X
-      Aggregate w keys codegen=true            1535 / 1571         13.7          73.2       1.6X
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4
+    Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
+    Aggregate w keys:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    codegen = F                              2219 / 2392          9.4         105.8       1.0X
+    codegen = T hashmap = F                  1330 / 1466         15.8          63.4       1.7X
+    codegen = T hashmap = T                   384 /  518         54.7          18.3       5.8X
     */
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b5c60bcd/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 94fbcb7..84bb7ed 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -967,27 +967,32 @@ class TungstenAggregationQuerySuite extends AggregationQuerySuite
 class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite {
 
   override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = {
-    (0 to 2).foreach { fallbackStartsAt =>
-      withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" -> fallbackStartsAt.toString) {
-        // Create a new df to make sure its physical operator picks up
-        // spark.sql.TungstenAggregate.testFallbackStartsAt.
-        // todo: remove it?
-        val newActual = Dataset.ofRows(sqlContext, actual.logicalPlan)
-
-        QueryTest.checkAnswer(newActual, expectedAnswer) match {
-          case Some(errorMessage) =>
-            val newErrorMessage =
-              s"""
-                |The following aggregation query failed when using TungstenAggregate with
-                |controlled fallback (it falls back to sort-based aggregation once it has processed
-                |$fallbackStartsAt input rows). The query is
-                |${actual.queryExecution}
-                |
-                |$errorMessage
-              """.stripMargin
-
-            fail(newErrorMessage)
-          case None =>
+    Seq(false, true).foreach { enableColumnarHashMap =>
+      withSQLConf("spark.sql.codegen.aggregate.map.enabled" -> enableColumnarHashMap.toString) {
+        (1 to 3).foreach { fallbackStartsAt =>
+          withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" ->
+            s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") {
+            // Create a new df to make sure its physical operator picks up
+            // spark.sql.TungstenAggregate.testFallbackStartsAt.
+            // todo: remove it?
+            val newActual = Dataset.ofRows(sqlContext, actual.logicalPlan)
+
+            QueryTest.checkAnswer(newActual, expectedAnswer) match {
+              case Some(errorMessage) =>
+                val newErrorMessage =
+                  s"""
+                    |The following aggregation query failed when using TungstenAggregate with
+                    |controlled fallback (it falls back to bytes to bytes map once it has processed
+                    |${fallbackStartsAt -1} input rows and to sort-based aggregation once it has
+                    |processed $fallbackStartsAt input rows). The query is ${actual.queryExecution}
+                    |
+                    |$errorMessage
+                  """.stripMargin
+
+                fail(newErrorMessage)
+              case None => // Success
+            }
+          }
         }
       }
     }

