You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yh...@apache.org on 2015/12/14 07:57:09 UTC

[2/2] spark git commit: [SPARK-12213][SQL] use multiple partitions for single distinct query

[SPARK-12213][SQL] use multiple partitions for single distinct query

Currently, we could generate different plans for a query with a single distinct aggregate (depending on spark.sql.specializeSingleDistinctAggPlanning): one works better on low-cardinality columns, the other
works better on high-cardinality columns (the default one).

This PR changes the planner to generate a single plan (three aggregations and two exchanges) that works better in both cases, so we can safely remove the flag `spark.sql.specializeSingleDistinctAggPlanning` (introduced in 1.6).

For a query like `SELECT COUNT(DISTINCT a) FROM table`, the plan will be:
```
AGG-4 (count distinct)
  Shuffle to a single reducer
    Partial-AGG-3 (count distinct, no grouping)
      Partial-AGG-2 (grouping on a)
        Shuffle by a
          Partial-AGG-1 (grouping on a)
```

This PR also includes a large refactor of the aggregation code (removing 500+ lines of code).

cc yhuai nongli marmbrus

Author: Davies Liu <da...@databricks.com>

Closes #10228 from davies/single_distinct.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/834e7148
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/834e7148
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/834e7148

Branch: refs/heads/master
Commit: 834e71489bf560302f9d743dff669df1134e9b74
Parents: 2aecda2
Author: Davies Liu <da...@databricks.com>
Authored: Sun Dec 13 22:57:01 2015 -0800
Committer: Yin Huai <yh...@databricks.com>
Committed: Sun Dec 13 22:57:01 2015 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/CatalystConf.scala       |   7 -
 .../analysis/DistinctAggregationRewriter.scala  |  11 +-
 .../scala/org/apache/spark/sql/SQLConf.scala    |  15 -
 .../aggregate/AggregationIterator.scala         | 417 ++++++------------
 .../aggregate/SortBasedAggregate.scala          |  29 +-
 .../SortBasedAggregationIterator.scala          |  47 +-
 .../execution/aggregate/TungstenAggregate.scala |  25 +-
 .../aggregate/TungstenAggregationIterator.scala | 439 +++----------------
 .../spark/sql/execution/aggregate/utils.scala   | 280 ++++++------
 .../hive/execution/AggregationQuerySuite.scala  | 142 +++---
 10 files changed, 422 insertions(+), 990 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/834e7148/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
index 7c2b8a9..2c7c58e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql.catalyst
 
 private[spark] trait CatalystConf {
   def caseSensitiveAnalysis: Boolean
-
-  protected[spark] def specializeSingleDistinctAggPlanning: Boolean
 }
 
 /**
@@ -31,13 +29,8 @@ object EmptyConf extends CatalystConf {
   override def caseSensitiveAnalysis: Boolean = {
     throw new UnsupportedOperationException
   }
-
-  protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = {
-    throw new UnsupportedOperationException
-  }
 }
 
 /** A CatalystConf that can be used for local testing. */
 case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends CatalystConf {
-  protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = true
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/834e7148/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
index 9c78f6d..4e7d134 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
@@ -123,15 +123,8 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
       .filter(_.isDistinct)
       .groupBy(_.aggregateFunction.children.toSet)
 
-    val shouldRewrite = if (conf.specializeSingleDistinctAggPlanning) {
-      // When the flag is set to specialize single distinct agg planning,
-      // we will rely on our Aggregation strategy to handle queries with a single
-      // distinct column.
-      distinctAggGroups.size > 1
-    } else {
-      distinctAggGroups.size >= 1
-    }
-    if (shouldRewrite) {
+    // Aggregation strategy can handle the query with single distinct
+    if (distinctAggGroups.size > 1) {
       // Create the attributes for the grouping id and the group by clause.
       val gid = new AttributeReference("gid", IntegerType, false)()
       val groupByMap = a.groupingExpressions.collect {

http://git-wip-us.apache.org/repos/asf/spark/blob/834e7148/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 58adf64..3d81926 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -449,18 +449,6 @@ private[spark] object SQLConf {
     doc = "When true, we could use `datasource`.`path` as table in SQL query"
   )
 
-  val SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING =
-    booleanConf("spark.sql.specializeSingleDistinctAggPlanning",
-      defaultValue = Some(false),
-      isPublic = false,
-      doc = "When true, if a query only has a single distinct column and it has " +
-        "grouping expressions, we will use our planner rule to handle this distinct " +
-        "column (other cases are handled by DistinctAggregationRewriter). " +
-        "When false, we will always use DistinctAggregationRewriter to plan " +
-        "aggregation queries with DISTINCT keyword. This is an internal flag that is " +
-        "used to benchmark the performance impact of using DistinctAggregationRewriter to " +
-        "plan aggregation queries with a single distinct column.")
-
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
     val EXTERNAL_SORT = "spark.sql.planner.externalSort"
@@ -579,9 +567,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
 
   private[spark] def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES)
 
-  protected[spark] override def specializeSingleDistinctAggPlanning: Boolean =
-    getConf(SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING)
-
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */

http://git-wip-us.apache.org/repos/asf/spark/blob/834e7148/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 008478a..0c74df0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -17,15 +17,15 @@
 
 package org.apache.spark.sql.execution.aggregate
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 
-import scala.collection.mutable.ArrayBuffer
-
 /**
- * The base class of [[SortBasedAggregationIterator]].
+ * The base class of [[SortBasedAggregationIterator]] and [[TungstenAggregationIterator]].
  * It mainly contains two parts:
  * 1. It initializes aggregate functions.
  * 2. It creates two functions, `processRow` and `generateOutput` based on [[AggregateMode]] of
@@ -33,64 +33,58 @@ import scala.collection.mutable.ArrayBuffer
  *    is used to generate result.
  */
 abstract class AggregationIterator(
-    groupingKeyAttributes: Seq[Attribute],
-    valueAttributes: Seq[Attribute],
-    nonCompleteAggregateExpressions: Seq[AggregateExpression],
-    nonCompleteAggregateAttributes: Seq[Attribute],
-    completeAggregateExpressions: Seq[AggregateExpression],
-    completeAggregateAttributes: Seq[Attribute],
+    groupingExpressions: Seq[NamedExpression],
+    inputAttributes: Seq[Attribute],
+    aggregateExpressions: Seq[AggregateExpression],
+    aggregateAttributes: Seq[Attribute],
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
-    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
-    outputsUnsafeRows: Boolean)
-  extends Iterator[InternalRow] with Logging {
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection))
+  extends Iterator[UnsafeRow] with Logging {
 
   ///////////////////////////////////////////////////////////////////////////
   // Initializing functions.
   ///////////////////////////////////////////////////////////////////////////
 
-  // An Seq of all AggregateExpressions.
-  // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final
-  // are at the beginning of the allAggregateExpressions.
-  protected val allAggregateExpressions =
-    nonCompleteAggregateExpressions ++ completeAggregateExpressions
-
-  require(
-    allAggregateExpressions.map(_.mode).distinct.length <= 2,
-    s"$allAggregateExpressions are not supported becuase they have more than 2 distinct modes.")
-
   /**
-   * The distinct modes of AggregateExpressions. Right now, we can handle the following mode:
-   *  - Partial-only: all AggregateExpressions have the mode of Partial;
-   *  - PartialMerge-only: all AggregateExpressions have the mode of PartialMerge);
-   *  - Final-only: all AggregateExpressions have the mode of Final;
-   *  - Final-Complete: some AggregateExpressions have the mode of Final and
-   *    others have the mode of Complete;
-   *  - Complete-only: nonCompleteAggregateExpressions is empty and we have AggregateExpressions
-   *    with mode Complete in completeAggregateExpressions; and
-   *  - Grouping-only: there is no AggregateExpression.
-   */
-  protected val aggregationMode: (Option[AggregateMode], Option[AggregateMode]) =
-    nonCompleteAggregateExpressions.map(_.mode).distinct.headOption ->
-      completeAggregateExpressions.map(_.mode).distinct.headOption
+    * The following combinations of AggregationMode are supported:
+    * - Partial
+    * - PartialMerge (for single distinct)
+    * - Partial and PartialMerge (for single distinct)
+    * - Final
+    * - Complete (for SortBasedAggregate with functions that does not support Partial)
+    * - Final and Complete (currently not used)
+    *
+    * TODO: AggregateMode should have only two modes: Update and Merge, AggregateExpression
+    * could have a flag to tell it's final or not.
+    */
+  {
+    val modes = aggregateExpressions.map(_.mode).distinct.toSet
+    require(modes.size <= 2,
+      s"$aggregateExpressions are not supported because they have more than 2 distinct modes.")
+    require(modes.subsetOf(Set(Partial, PartialMerge)) || modes.subsetOf(Set(Final, Complete)),
+      s"$aggregateExpressions can't have Partial/PartialMerge and Final/Complete in the same time.")
+  }
 
   // Initialize all AggregateFunctions by binding references if necessary,
   // and set inputBufferOffset and mutableBufferOffset.
-  protected val allAggregateFunctions: Array[AggregateFunction] = {
+  protected def initializeAggregateFunctions(
+      expressions: Seq[AggregateExpression],
+      startingInputBufferOffset: Int): Array[AggregateFunction] = {
     var mutableBufferOffset = 0
-    var inputBufferOffset: Int = initialInputBufferOffset
-    val functions = new Array[AggregateFunction](allAggregateExpressions.length)
+    var inputBufferOffset: Int = startingInputBufferOffset
+    val functions = new Array[AggregateFunction](expressions.length)
     var i = 0
-    while (i < allAggregateExpressions.length) {
-      val func = allAggregateExpressions(i).aggregateFunction
-      val funcWithBoundReferences: AggregateFunction = allAggregateExpressions(i).mode match {
+    while (i < expressions.length) {
+      val func = expressions(i).aggregateFunction
+      val funcWithBoundReferences: AggregateFunction = expressions(i).mode match {
         case Partial | Complete if func.isInstanceOf[ImperativeAggregate] =>
           // We need to create BoundReferences if the function is not an
           // expression-based aggregate function (it does not support code-gen) and the mode of
           // this function is Partial or Complete because we will call eval of this
           // function's children in the update method of this aggregate function.
           // Those eval calls require BoundReferences to work.
-          BindReferences.bindReference(func, valueAttributes)
+          BindReferences.bindReference(func, inputAttributes)
         case _ =>
           // We only need to set inputBufferOffset for aggregate functions with mode
           // PartialMerge and Final.
@@ -117,15 +111,18 @@ abstract class AggregationIterator(
     functions
   }
 
+  protected val aggregateFunctions: Array[AggregateFunction] =
+    initializeAggregateFunctions(aggregateExpressions, initialInputBufferOffset)
+
   // Positions of those imperative aggregate functions in allAggregateFunctions.
   // For example, we have func1, func2, func3, func4 in aggregateFunctions, and
   // func2 and func3 are imperative aggregate functions.
   // ImperativeAggregateFunctionPositions will be [1, 2].
-  private[this] val allImperativeAggregateFunctionPositions: Array[Int] = {
+  protected[this] val allImperativeAggregateFunctionPositions: Array[Int] = {
     val positions = new ArrayBuffer[Int]()
     var i = 0
-    while (i < allAggregateFunctions.length) {
-      allAggregateFunctions(i) match {
+    while (i < aggregateFunctions.length) {
+      aggregateFunctions(i) match {
         case agg: DeclarativeAggregate =>
         case _ => positions += i
       }
@@ -134,17 +131,9 @@ abstract class AggregationIterator(
     positions.toArray
   }
 
-  // All AggregateFunctions functions with mode Partial, PartialMerge, or Final.
-  private[this] val nonCompleteAggregateFunctions: Array[AggregateFunction] =
-    allAggregateFunctions.take(nonCompleteAggregateExpressions.length)
-
-  // All imperative aggregate functions with mode Partial, PartialMerge, or Final.
-  private[this] val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] =
-    nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func }
-
   // The projection used to initialize buffer values for all expression-based aggregates.
-  private[this] val expressionAggInitialProjection = {
-    val initExpressions = allAggregateFunctions.flatMap {
+  protected[this] val expressionAggInitialProjection = {
+    val initExpressions = aggregateFunctions.flatMap {
       case ae: DeclarativeAggregate => ae.initialValues
       // For the positions corresponding to imperative aggregate functions, we'll use special
       // no-op expressions which are ignored during projection code-generation.
@@ -154,248 +143,112 @@ abstract class AggregationIterator(
   }
 
   // All imperative AggregateFunctions.
-  private[this] val allImperativeAggregateFunctions: Array[ImperativeAggregate] =
+  protected[this] val allImperativeAggregateFunctions: Array[ImperativeAggregate] =
     allImperativeAggregateFunctionPositions
-      .map(allAggregateFunctions)
+      .map(aggregateFunctions)
       .map(_.asInstanceOf[ImperativeAggregate])
 
-  ///////////////////////////////////////////////////////////////////////////
-  // Methods and fields used by sub-classes.
-  ///////////////////////////////////////////////////////////////////////////
-
   // Initializing functions used to process a row.
-  protected val processRow: (MutableRow, InternalRow) => Unit = {
-    val rowToBeProcessed = new JoinedRow
-    val aggregationBufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes)
-    aggregationMode match {
-      // Partial-only
-      case (Some(Partial), None) =>
-        val updateExpressions = nonCompleteAggregateFunctions.flatMap {
-          case ae: DeclarativeAggregate => ae.updateExpressions
-          case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
-        }
-        val expressionAggUpdateProjection =
-          newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)()
-
-        (currentBuffer: MutableRow, row: InternalRow) => {
-          expressionAggUpdateProjection.target(currentBuffer)
-          // Process all expression-based aggregate functions.
-          expressionAggUpdateProjection(rowToBeProcessed(currentBuffer, row))
-          // Process all imperative aggregate functions.
-          var i = 0
-          while (i < nonCompleteImperativeAggregateFunctions.length) {
-            nonCompleteImperativeAggregateFunctions(i).update(currentBuffer, row)
-            i += 1
-          }
-        }
-
-      // PartialMerge-only or Final-only
-      case (Some(PartialMerge), None) | (Some(Final), None) =>
-        val inputAggregationBufferSchema = if (initialInputBufferOffset == 0) {
-          // If initialInputBufferOffset, the input value does not contain
-          // grouping keys.
-          // This part is pretty hacky.
-          allAggregateFunctions.flatMap(_.inputAggBufferAttributes).toSeq
-        } else {
-          groupingKeyAttributes ++ allAggregateFunctions.flatMap(_.inputAggBufferAttributes)
-        }
-        // val inputAggregationBufferSchema =
-        //  groupingKeyAttributes ++
-        //    allAggregateFunctions.flatMap(_.cloneBufferAttributes)
-        val mergeExpressions = nonCompleteAggregateFunctions.flatMap {
-          case ae: DeclarativeAggregate => ae.mergeExpressions
-          case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
-        }
-        // This projection is used to merge buffer values for all expression-based aggregates.
-        val expressionAggMergeProjection =
-          newMutableProjection(
-            mergeExpressions,
-            aggregationBufferSchema ++ inputAggregationBufferSchema)()
-
-        (currentBuffer: MutableRow, row: InternalRow) => {
-          // Process all expression-based aggregate functions.
-          expressionAggMergeProjection.target(currentBuffer)(rowToBeProcessed(currentBuffer, row))
-          // Process all imperative aggregate functions.
-          var i = 0
-          while (i < nonCompleteImperativeAggregateFunctions.length) {
-            nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row)
-            i += 1
-          }
-        }
-
-      // Final-Complete
-      case (Some(Final), Some(Complete)) =>
-        val completeAggregateFunctions: Array[AggregateFunction] =
-          allAggregateFunctions.takeRight(completeAggregateExpressions.length)
-        // All imperative aggregate functions with mode Complete.
-        val completeImperativeAggregateFunctions: Array[ImperativeAggregate] =
-          completeAggregateFunctions.collect { case func: ImperativeAggregate => func }
-
-        // The first initialInputBufferOffset values of the input aggregation buffer is
-        // for grouping expressions and distinct columns.
-        val groupingAttributesAndDistinctColumns = valueAttributes.take(initialInputBufferOffset)
-
-        val completeOffsetExpressions =
-          Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp)
-        // We do not touch buffer values of aggregate functions with the Final mode.
-        val finalOffsetExpressions =
-          Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp)
-
-        val mergeInputSchema =
-          aggregationBufferSchema ++
-            groupingAttributesAndDistinctColumns ++
-            nonCompleteAggregateFunctions.flatMap(_.inputAggBufferAttributes)
-        val mergeExpressions =
-          nonCompleteAggregateFunctions.flatMap {
-            case ae: DeclarativeAggregate => ae.mergeExpressions
-            case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
-          } ++ completeOffsetExpressions
-        val finalExpressionAggMergeProjection =
-          newMutableProjection(mergeExpressions, mergeInputSchema)()
-
-        val updateExpressions =
-          finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
-            case ae: DeclarativeAggregate => ae.updateExpressions
-            case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
-          }
-        val completeExpressionAggUpdateProjection =
-          newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)()
-
-        (currentBuffer: MutableRow, row: InternalRow) => {
-          val input = rowToBeProcessed(currentBuffer, row)
-          // For all aggregate functions with mode Complete, update buffers.
-          completeExpressionAggUpdateProjection.target(currentBuffer)(input)
-          var i = 0
-          while (i < completeImperativeAggregateFunctions.length) {
-            completeImperativeAggregateFunctions(i).update(currentBuffer, row)
-            i += 1
-          }
-
-          // For all aggregate functions with mode Final, merge buffers.
-          finalExpressionAggMergeProjection.target(currentBuffer)(input)
-          i = 0
-          while (i < nonCompleteImperativeAggregateFunctions.length) {
-            nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row)
-            i += 1
+  protected def generateProcessRow(
+      expressions: Seq[AggregateExpression],
+      functions: Seq[AggregateFunction],
+      inputAttributes: Seq[Attribute]): (MutableRow, InternalRow) => Unit = {
+    val joinedRow = new JoinedRow
+    if (expressions.nonEmpty) {
+      val mergeExpressions = functions.zipWithIndex.flatMap {
+        case (ae: DeclarativeAggregate, i) =>
+          expressions(i).mode match {
+            case Partial | Complete => ae.updateExpressions
+            case PartialMerge | Final => ae.mergeExpressions
           }
-        }
-
-      // Complete-only
-      case (None, Some(Complete)) =>
-        val completeAggregateFunctions: Array[AggregateFunction] =
-          allAggregateFunctions.takeRight(completeAggregateExpressions.length)
-        // All imperative aggregate functions with mode Complete.
-        val completeImperativeAggregateFunctions: Array[ImperativeAggregate] =
-          completeAggregateFunctions.collect { case func: ImperativeAggregate => func }
-
-        val updateExpressions =
-          completeAggregateFunctions.flatMap {
-            case ae: DeclarativeAggregate => ae.updateExpressions
-            case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
-          }
-        val completeExpressionAggUpdateProjection =
-          newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)()
-
-        (currentBuffer: MutableRow, row: InternalRow) => {
-          val input = rowToBeProcessed(currentBuffer, row)
-          // For all aggregate functions with mode Complete, update buffers.
-          completeExpressionAggUpdateProjection.target(currentBuffer)(input)
-          var i = 0
-          while (i < completeImperativeAggregateFunctions.length) {
-            completeImperativeAggregateFunctions(i).update(currentBuffer, row)
-            i += 1
+        case (agg: AggregateFunction, _) => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+      }
+      val updateFunctions = functions.zipWithIndex.collect {
+        case (ae: ImperativeAggregate, i) =>
+          expressions(i).mode match {
+            case Partial | Complete =>
+              (buffer: MutableRow, row: InternalRow) => ae.update(buffer, row)
+            case PartialMerge | Final =>
+              (buffer: MutableRow, row: InternalRow) => ae.merge(buffer, row)
           }
+      }
+      // This projection is used to merge buffer values for all expression-based aggregates.
+      val aggregationBufferSchema = functions.flatMap(_.aggBufferAttributes)
+      val updateProjection =
+        newMutableProjection(mergeExpressions, aggregationBufferSchema ++ inputAttributes)()
+
+      (currentBuffer: MutableRow, row: InternalRow) => {
+        // Process all expression-based aggregate functions.
+        updateProjection.target(currentBuffer)(joinedRow(currentBuffer, row))
+        // Process all imperative aggregate functions.
+        var i = 0
+        while (i < updateFunctions.length) {
+          updateFunctions(i)(currentBuffer, row)
+          i += 1
         }
-
+      }
+    } else {
       // Grouping only.
-      case (None, None) => (currentBuffer: MutableRow, row: InternalRow) => {}
-
-      case other =>
-        sys.error(
-          s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " +
-            s"support evaluate modes $other in this iterator.")
+      (currentBuffer: MutableRow, row: InternalRow) => {}
     }
   }
 
-  // Initializing the function used to generate the output row.
-  protected val generateOutput: (InternalRow, MutableRow) => InternalRow = {
-    val rowToBeEvaluated = new JoinedRow
-    val safeOutputRow = new SpecificMutableRow(resultExpressions.map(_.dataType))
-    val mutableOutput = if (outputsUnsafeRows) {
-      UnsafeProjection.create(resultExpressions.map(_.dataType).toArray).apply(safeOutputRow)
-    } else {
-      safeOutputRow
-    }
-
-    aggregationMode match {
-      // Partial-only or PartialMerge-only: every output row is basically the values of
-      // the grouping expressions and the corresponding aggregation buffer.
-      case (Some(Partial), None) | (Some(PartialMerge), None) =>
-        // Because we cannot copy a joinedRow containing a UnsafeRow (UnsafeRow does not
-        // support generic getter), we create a mutable projection to output the
-        // JoinedRow(currentGroupingKey, currentBuffer)
-        val bufferSchema = nonCompleteAggregateFunctions.flatMap(_.aggBufferAttributes)
-        val resultProjection =
-          newMutableProjection(
-            groupingKeyAttributes ++ bufferSchema,
-            groupingKeyAttributes ++ bufferSchema)()
-        resultProjection.target(mutableOutput)
-
-        (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => {
-          resultProjection(rowToBeEvaluated(currentGroupingKey, currentBuffer))
-          // rowToBeEvaluated(currentGroupingKey, currentBuffer)
-        }
+  protected val processRow: (MutableRow, InternalRow) => Unit =
+    generateProcessRow(aggregateExpressions, aggregateFunctions, inputAttributes)
 
-      // Final-only, Complete-only and Final-Complete: every output row contains values representing
-      // resultExpressions.
-      case (Some(Final), None) | (Some(Final) | None, Some(Complete)) =>
-        val bufferSchemata =
-          allAggregateFunctions.flatMap(_.aggBufferAttributes)
-        val evalExpressions = allAggregateFunctions.map {
-          case ae: DeclarativeAggregate => ae.evaluateExpression
-          case agg: AggregateFunction => NoOp
-        }
-        val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)()
-        val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes
-        // TODO: Use unsafe row.
-        val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType))
-        expressionAggEvalProjection.target(aggregateResult)
-        val resultProjection =
-          newMutableProjection(
-            resultExpressions, groupingKeyAttributes ++ aggregateResultSchema)()
-        resultProjection.target(mutableOutput)
+  protected val groupingProjection: UnsafeProjection =
+    UnsafeProjection.create(groupingExpressions, inputAttributes)
+  protected val groupingAttributes = groupingExpressions.map(_.toAttribute)
 
-        (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => {
-          // Generate results for all expression-based aggregate functions.
-          expressionAggEvalProjection(currentBuffer)
-          // Generate results for all imperative aggregate functions.
-          var i = 0
-          while (i < allImperativeAggregateFunctions.length) {
-            aggregateResult.update(
-              allImperativeAggregateFunctionPositions(i),
-              allImperativeAggregateFunctions(i).eval(currentBuffer))
-            i += 1
-          }
-          resultProjection(rowToBeEvaluated(currentGroupingKey, aggregateResult))
+  // Initializing the function used to generate the output row.
+  protected def generateResultProjection(): (UnsafeRow, MutableRow) => UnsafeRow = {
+    val joinedRow = new JoinedRow
+    val modes = aggregateExpressions.map(_.mode).distinct
+    val bufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes)
+    if (modes.contains(Final) || modes.contains(Complete)) {
+      val evalExpressions = aggregateFunctions.map {
+        case ae: DeclarativeAggregate => ae.evaluateExpression
+        case agg: AggregateFunction => NoOp
+      }
+      val aggregateResult = new SpecificMutableRow(aggregateAttributes.map(_.dataType))
+      val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)()
+      expressionAggEvalProjection.target(aggregateResult)
+
+      val resultProjection =
+        UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateAttributes)
+
+      (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => {
+        // Generate results for all expression-based aggregate functions.
+        expressionAggEvalProjection(currentBuffer)
+        // Generate results for all imperative aggregate functions.
+        var i = 0
+        while (i < allImperativeAggregateFunctions.length) {
+          aggregateResult.update(
+            allImperativeAggregateFunctionPositions(i),
+            allImperativeAggregateFunctions(i).eval(currentBuffer))
+          i += 1
         }
-
+        resultProjection(joinedRow(currentGroupingKey, aggregateResult))
+      }
+    } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
+      val resultProjection = UnsafeProjection.create(
+        groupingAttributes ++ bufferAttributes,
+        groupingAttributes ++ bufferAttributes)
+      (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => {
+        resultProjection(joinedRow(currentGroupingKey, currentBuffer))
+      }
+    } else {
       // Grouping-only: we only output values of grouping expressions.
-      case (None, None) =>
-        val resultProjection =
-          newMutableProjection(resultExpressions, groupingKeyAttributes)()
-        resultProjection.target(mutableOutput)
-
-        (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => {
-          resultProjection(currentGroupingKey)
-        }
-
-      case other =>
-        sys.error(
-          s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " +
-            s"support evaluate modes $other in this iterator.")
+      val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes)
+      (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => {
+        resultProjection(currentGroupingKey)
+      }
     }
   }
 
+  protected val generateOutput: (UnsafeRow, MutableRow) => UnsafeRow =
+    generateResultProjection()
+
   /** Initializes buffer values for all aggregate functions. */
   protected def initializeBuffer(buffer: MutableRow): Unit = {
     expressionAggInitialProjection.target(buffer)(EmptyRow)
@@ -405,10 +258,4 @@ abstract class AggregationIterator(
       i += 1
     }
   }
-
-  /**
-   * Creates a new aggregation buffer and initializes buffer values
-   * for all aggregate functions.
-   */
-  protected def newBuffer: MutableRow
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/834e7148/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
index ee98245..c5470a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
@@ -29,10 +29,8 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
 case class SortBasedAggregate(
     requiredChildDistributionExpressions: Option[Seq[Expression]],
     groupingExpressions: Seq[NamedExpression],
-    nonCompleteAggregateExpressions: Seq[AggregateExpression],
-    nonCompleteAggregateAttributes: Seq[Attribute],
-    completeAggregateExpressions: Seq[AggregateExpression],
-    completeAggregateAttributes: Seq[Attribute],
+    aggregateExpressions: Seq[AggregateExpression],
+    aggregateAttributes: Seq[Attribute],
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
@@ -42,10 +40,8 @@ case class SortBasedAggregate(
     "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
 
-  override def outputsUnsafeRows: Boolean = false
-
+  override def outputsUnsafeRows: Boolean = true
   override def canProcessUnsafeRows: Boolean = false
-
   override def canProcessSafeRows: Boolean = true
 
   override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
@@ -76,31 +72,24 @@ case class SortBasedAggregate(
       if (!hasInput && groupingExpressions.nonEmpty) {
         // This is a grouped aggregate and the input iterator is empty,
         // so return an empty iterator.
-        Iterator[InternalRow]()
+        Iterator[UnsafeRow]()
       } else {
-        val groupingKeyProjection =
-          UnsafeProjection.create(groupingExpressions, child.output)
-
         val outputIter = new SortBasedAggregationIterator(
-          groupingKeyProjection,
-          groupingExpressions.map(_.toAttribute),
+          groupingExpressions,
           child.output,
           iter,
-          nonCompleteAggregateExpressions,
-          nonCompleteAggregateAttributes,
-          completeAggregateExpressions,
-          completeAggregateAttributes,
+          aggregateExpressions,
+          aggregateAttributes,
           initialInputBufferOffset,
           resultExpressions,
           newMutableProjection,
-          outputsUnsafeRows,
           numInputRows,
           numOutputRows)
         if (!hasInput && groupingExpressions.isEmpty) {
           // There is no input and there is no grouping expressions.
           // We need to output a single row as the output.
           numOutputRows += 1
-          Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput())
+          Iterator[UnsafeRow](outputIter.outputForEmptyGroupingKeyWithoutInput())
         } else {
           outputIter
         }
@@ -109,7 +98,7 @@ case class SortBasedAggregate(
   }
 
   override def simpleString: String = {
-    val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions
+    val allAggregateExpressions = aggregateExpressions
 
     val keyString = groupingExpressions.mkString("[", ",", "]")
     val functionString = allAggregateExpressions.mkString("[", ",", "]")

http://git-wip-us.apache.org/repos/asf/spark/blob/834e7148/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index fe5c319..ac920aa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -24,37 +24,34 @@ import org.apache.spark.sql.execution.metric.LongSQLMetric
 
 /**
  * An iterator used to evaluate [[AggregateFunction]]. It assumes the input rows have been
- * sorted by values of [[groupingKeyAttributes]].
+ * sorted by values of [[groupingExpressions]].
  */
 class SortBasedAggregationIterator(
-    groupingKeyProjection: InternalRow => InternalRow,
-    groupingKeyAttributes: Seq[Attribute],
+    groupingExpressions: Seq[NamedExpression],
     valueAttributes: Seq[Attribute],
     inputIterator: Iterator[InternalRow],
-    nonCompleteAggregateExpressions: Seq[AggregateExpression],
-    nonCompleteAggregateAttributes: Seq[Attribute],
-    completeAggregateExpressions: Seq[AggregateExpression],
-    completeAggregateAttributes: Seq[Attribute],
+    aggregateExpressions: Seq[AggregateExpression],
+    aggregateAttributes: Seq[Attribute],
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
-    outputsUnsafeRows: Boolean,
     numInputRows: LongSQLMetric,
     numOutputRows: LongSQLMetric)
   extends AggregationIterator(
-    groupingKeyAttributes,
+    groupingExpressions,
     valueAttributes,
-    nonCompleteAggregateExpressions,
-    nonCompleteAggregateAttributes,
-    completeAggregateExpressions,
-    completeAggregateAttributes,
+    aggregateExpressions,
+    aggregateAttributes,
     initialInputBufferOffset,
     resultExpressions,
-    newMutableProjection,
-    outputsUnsafeRows) {
-
-  override protected def newBuffer: MutableRow = {
-    val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes)
+    newMutableProjection) {
+
+  /**
+    * Creates a new aggregation buffer and initializes buffer values
+    * for all aggregate functions.
+    */
+  private def newBuffer: MutableRow = {
+    val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes)
     val bufferRowSize: Int = bufferSchema.length
 
     val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
@@ -76,10 +73,10 @@ class SortBasedAggregationIterator(
   ///////////////////////////////////////////////////////////////////////////
 
   // The partition key of the current partition.
-  private[this] var currentGroupingKey: InternalRow = _
+  private[this] var currentGroupingKey: UnsafeRow = _
 
   // The partition key of next partition.
-  private[this] var nextGroupingKey: InternalRow = _
+  private[this] var nextGroupingKey: UnsafeRow = _
 
   // The first row of next partition.
   private[this] var firstRowInNextGroup: InternalRow = _
@@ -94,7 +91,7 @@ class SortBasedAggregationIterator(
     if (inputIterator.hasNext) {
       initializeBuffer(sortBasedAggregationBuffer)
       val inputRow = inputIterator.next()
-      nextGroupingKey = groupingKeyProjection(inputRow).copy()
+      nextGroupingKey = groupingProjection(inputRow).copy()
       firstRowInNextGroup = inputRow.copy()
       numInputRows += 1
       sortedInputHasNewGroup = true
@@ -120,7 +117,7 @@ class SortBasedAggregationIterator(
     while (!findNextPartition && inputIterator.hasNext) {
       // Get the grouping key.
       val currentRow = inputIterator.next()
-      val groupingKey = groupingKeyProjection(currentRow)
+      val groupingKey = groupingProjection(currentRow)
       numInputRows += 1
 
       // Check if the current row belongs the current input row.
@@ -146,7 +143,7 @@ class SortBasedAggregationIterator(
 
   override final def hasNext: Boolean = sortedInputHasNewGroup
 
-  override final def next(): InternalRow = {
+  override final def next(): UnsafeRow = {
     if (hasNext) {
       // Process the current group.
       processCurrentSortedGroup()
@@ -162,8 +159,8 @@ class SortBasedAggregationIterator(
     }
   }
 
-  def outputForEmptyGroupingKeyWithoutInput(): InternalRow = {
+  def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
     initializeBuffer(sortBasedAggregationBuffer)
-    generateOutput(new GenericInternalRow(0), sortBasedAggregationBuffer)
+    generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/834e7148/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 920de61..b8849c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -30,21 +30,18 @@ import org.apache.spark.sql.types.StructType
 case class TungstenAggregate(
     requiredChildDistributionExpressions: Option[Seq[Expression]],
     groupingExpressions: Seq[NamedExpression],
-    nonCompleteAggregateExpressions: Seq[AggregateExpression],
-    nonCompleteAggregateAttributes: Seq[Attribute],
-    completeAggregateExpressions: Seq[AggregateExpression],
-    completeAggregateAttributes: Seq[Attribute],
+    aggregateExpressions: Seq[AggregateExpression],
+    aggregateAttributes: Seq[Attribute],
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
   extends UnaryNode {
 
   private[this] val aggregateBufferAttributes = {
-    (nonCompleteAggregateExpressions ++ completeAggregateExpressions)
-      .flatMap(_.aggregateFunction.aggBufferAttributes)
+    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
   }
 
-  require(TungstenAggregate.supportsAggregate(groupingExpressions, aggregateBufferAttributes))
+  require(TungstenAggregate.supportsAggregate(aggregateBufferAttributes))
 
   override private[sql] lazy val metrics = Map(
     "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
@@ -53,9 +50,7 @@ case class TungstenAggregate(
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))
 
   override def outputsUnsafeRows: Boolean = true
-
   override def canProcessUnsafeRows: Boolean = true
-
   override def canProcessSafeRows: Boolean = true
 
   override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
@@ -94,10 +89,8 @@ case class TungstenAggregate(
         val aggregationIterator =
           new TungstenAggregationIterator(
             groupingExpressions,
-            nonCompleteAggregateExpressions,
-            nonCompleteAggregateAttributes,
-            completeAggregateExpressions,
-            completeAggregateAttributes,
+            aggregateExpressions,
+            aggregateAttributes,
             initialInputBufferOffset,
             resultExpressions,
             newMutableProjection,
@@ -119,7 +112,7 @@ case class TungstenAggregate(
   }
 
   override def simpleString: String = {
-    val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions
+    val allAggregateExpressions = aggregateExpressions
 
     testFallbackStartsAt match {
       case None =>
@@ -135,9 +128,7 @@ case class TungstenAggregate(
 }
 
 object TungstenAggregate {
-  def supportsAggregate(
-    groupingExpressions: Seq[Expression],
-    aggregateBufferAttributes: Seq[Attribute]): Boolean = {
+  def supportsAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = {
     val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
     UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/834e7148/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 0439144..582fdbe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -17,17 +17,15 @@
 
 package org.apache.spark.sql.execution.aggregate
 
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.unsafe.KVIterator
-import org.apache.spark.{InternalAccumulator, Logging, TaskContext}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap}
 import org.apache.spark.sql.execution.metric.LongSQLMetric
+import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnsafeKVExternalSorter}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.KVIterator
+import org.apache.spark.{InternalAccumulator, Logging, TaskContext}
 
 /**
  * An iterator used to evaluate aggregate functions. It operates on [[UnsafeRow]]s.
@@ -63,15 +61,11 @@ import org.apache.spark.sql.types.StructType
  *
  * @param groupingExpressions
  *   expressions for grouping keys
- * @param nonCompleteAggregateExpressions
+ * @param aggregateExpressions
  * [[AggregateExpression]] containing [[AggregateFunction]]s with mode [[Partial]],
  * [[PartialMerge]], or [[Final]].
- * @param nonCompleteAggregateAttributes the attributes of the nonCompleteAggregateExpressions'
+ * @param aggregateAttributes the attributes of the aggregateExpressions'
  *   outputs when they are stored in the final aggregation buffer.
- * @param completeAggregateExpressions
- * [[AggregateExpression]] containing [[AggregateFunction]]s with mode [[Complete]].
- * @param completeAggregateAttributes the attributes of completeAggregateExpressions' outputs
- *   when they are stored in the final aggregation buffer.
  * @param resultExpressions
  *   expressions for generating output rows.
  * @param newMutableProjection
@@ -83,10 +77,8 @@ import org.apache.spark.sql.types.StructType
  */
 class TungstenAggregationIterator(
     groupingExpressions: Seq[NamedExpression],
-    nonCompleteAggregateExpressions: Seq[AggregateExpression],
-    nonCompleteAggregateAttributes: Seq[Attribute],
-    completeAggregateExpressions: Seq[AggregateExpression],
-    completeAggregateAttributes: Seq[Attribute],
+    aggregateExpressions: Seq[AggregateExpression],
+    aggregateAttributes: Seq[Attribute],
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
@@ -97,378 +89,62 @@ class TungstenAggregationIterator(
     numOutputRows: LongSQLMetric,
     dataSize: LongSQLMetric,
     spillSize: LongSQLMetric)
-  extends Iterator[UnsafeRow] with Logging {
+  extends AggregationIterator(
+    groupingExpressions,
+    originalInputAttributes,
+    aggregateExpressions,
+    aggregateAttributes,
+    initialInputBufferOffset,
+    resultExpressions,
+    newMutableProjection) with Logging {
 
   ///////////////////////////////////////////////////////////////////////////
   // Part 1: Initializing aggregate functions.
   ///////////////////////////////////////////////////////////////////////////
 
-  // A Seq containing all AggregateExpressions.
-  // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final
-  // are at the beginning of the allAggregateExpressions.
-  private[this] val allAggregateExpressions: Seq[AggregateExpression] =
-    nonCompleteAggregateExpressions ++ completeAggregateExpressions
-
-  // Check to make sure we do not have more than three modes in our AggregateExpressions.
-  // If we have, users are hitting a bug and we throw an IllegalStateException.
-  if (allAggregateExpressions.map(_.mode).distinct.length > 2) {
-    throw new IllegalStateException(
-      s"$allAggregateExpressions should have no more than 2 kinds of modes.")
-  }
-
   // Remember spill data size of this task before execute this operator so that we can
   // figure out how many bytes we spilled for this operator.
   private val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled
 
-  //
-  // The modes of AggregateExpressions. Right now, we can handle the following mode:
-  //  - Partial-only:
-  //      All AggregateExpressions have the mode of Partial.
-  //      For this case, aggregationMode is (Some(Partial), None).
-  //  - PartialMerge-only:
-  //      All AggregateExpressions have the mode of PartialMerge).
-  //      For this case, aggregationMode is (Some(PartialMerge), None).
-  //  - Final-only:
-  //      All AggregateExpressions have the mode of Final.
-  //      For this case, aggregationMode is (Some(Final), None).
-  //  - Final-Complete:
-  //      Some AggregateExpressions have the mode of Final and
-  //      others have the mode of Complete. For this case,
-  //      aggregationMode is (Some(Final), Some(Complete)).
-  //  - Complete-only:
-  //      nonCompleteAggregateExpressions is empty and we have AggregateExpressions
-  //      with mode Complete in completeAggregateExpressions. For this case,
-  //      aggregationMode is (None, Some(Complete)).
-  //  - Grouping-only:
-  //      There is no AggregateExpression. For this case, AggregationMode is (None,None).
-  //
-  private[this] var aggregationMode: (Option[AggregateMode], Option[AggregateMode]) = {
-    nonCompleteAggregateExpressions.map(_.mode).distinct.headOption ->
-      completeAggregateExpressions.map(_.mode).distinct.headOption
-  }
-
-  // Initialize all AggregateFunctions by binding references, if necessary,
-  // and setting inputBufferOffset and mutableBufferOffset.
-  private def initializeAllAggregateFunctions(
-      startingInputBufferOffset: Int): Array[AggregateFunction] = {
-    var mutableBufferOffset = 0
-    var inputBufferOffset: Int = startingInputBufferOffset
-    val functions = new Array[AggregateFunction](allAggregateExpressions.length)
-    var i = 0
-    while (i < allAggregateExpressions.length) {
-      val func = allAggregateExpressions(i).aggregateFunction
-      val aggregateExpressionIsNonComplete = i < nonCompleteAggregateExpressions.length
-      // We need to use this mode instead of func.mode in order to handle aggregation mode switching
-      // when switching to sort-based aggregation:
-      val mode = if (aggregateExpressionIsNonComplete) aggregationMode._1 else aggregationMode._2
-      val funcWithBoundReferences = mode match {
-        case Some(Partial) | Some(Complete) if func.isInstanceOf[ImperativeAggregate] =>
-          // We need to create BoundReferences if the function is not an
-          // expression-based aggregate function (it does not support code-gen) and the mode of
-          // this function is Partial or Complete because we will call eval of this
-          // function's children in the update method of this aggregate function.
-          // Those eval calls require BoundReferences to work.
-          BindReferences.bindReference(func, originalInputAttributes)
-        case _ =>
-          // We only need to set inputBufferOffset for aggregate functions with mode
-          // PartialMerge and Final.
-          val updatedFunc = func match {
-            case function: ImperativeAggregate =>
-              function.withNewInputAggBufferOffset(inputBufferOffset)
-            case function => function
-          }
-          inputBufferOffset += func.aggBufferSchema.length
-          updatedFunc
-      }
-      val funcWithUpdatedAggBufferOffset = funcWithBoundReferences match {
-        case function: ImperativeAggregate =>
-          // Set mutableBufferOffset for this function. It is important that setting
-          // mutableBufferOffset happens after all potential bindReference operations
-          // because bindReference will create a new instance of the function.
-          function.withNewMutableAggBufferOffset(mutableBufferOffset)
-        case function => function
-      }
-      mutableBufferOffset += funcWithUpdatedAggBufferOffset.aggBufferSchema.length
-      functions(i) = funcWithUpdatedAggBufferOffset
-      i += 1
-    }
-    functions
-  }
-
-  private[this] var allAggregateFunctions: Array[AggregateFunction] =
-    initializeAllAggregateFunctions(initialInputBufferOffset)
-
-  // Positions of those imperative aggregate functions in allAggregateFunctions.
-  // For example, say that we have func1, func2, func3, func4 in aggregateFunctions, and
-  // func2 and func3 are imperative aggregate functions. Then
-  // allImperativeAggregateFunctionPositions will be [1, 2]. Note that this does not need to be
-  // updated when falling back to sort-based aggregation because the positions of the aggregate
-  // functions do not change in that case.
-  private[this] val allImperativeAggregateFunctionPositions: Array[Int] = {
-    val positions = new ArrayBuffer[Int]()
-    var i = 0
-    while (i < allAggregateFunctions.length) {
-      allAggregateFunctions(i) match {
-        case agg: DeclarativeAggregate =>
-        case _ => positions += i
-      }
-      i += 1
-    }
-    positions.toArray
-  }
-
   ///////////////////////////////////////////////////////////////////////////
   // Part 2: Methods and fields used by setting aggregation buffer values,
   //         processing input rows from inputIter, and generating output
   //         rows.
   ///////////////////////////////////////////////////////////////////////////
 
-  // The projection used to initialize buffer values for all expression-based aggregates.
-  // Note that this projection does not need to be updated when switching to sort-based aggregation
-  // because the schema of empty aggregation buffers does not change in that case.
-  private[this] val expressionAggInitialProjection: MutableProjection = {
-    val initExpressions = allAggregateFunctions.flatMap {
-      case ae: DeclarativeAggregate => ae.initialValues
-      // For the positions corresponding to imperative aggregate functions, we'll use special
-      // no-op expressions which are ignored during projection code-generation.
-      case i: ImperativeAggregate => Seq.fill(i.aggBufferAttributes.length)(NoOp)
-    }
-    newMutableProjection(initExpressions, Nil)()
-  }
-
   // Creates a new aggregation buffer and initializes buffer values.
-  // This function should be only called at most three times (when we create the hash map,
-  // when we switch to sort-based aggregation, and when we create the re-used buffer for
-  // sort-based aggregation).
+  // This function should be only called at most two times (when we create the hash map,
+  // and when we create the re-used buffer for sort-based aggregation).
   private def createNewAggregationBuffer(): UnsafeRow = {
-    val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes)
+    val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes)
     val buffer: UnsafeRow = UnsafeProjection.create(bufferSchema.map(_.dataType))
       .apply(new GenericMutableRow(bufferSchema.length))
     // Initialize declarative aggregates' buffer values
     expressionAggInitialProjection.target(buffer)(EmptyRow)
     // Initialize imperative aggregates' buffer values
-    allAggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer))
+    aggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer))
     buffer
   }
 
-  // Creates a function used to process a row based on the given inputAttributes.
-  private def generateProcessRow(
-      inputAttributes: Seq[Attribute]): (UnsafeRow, InternalRow) => Unit = {
-
-    val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.aggBufferAttributes)
-    val joinedRow = new JoinedRow()
-
-    aggregationMode match {
-      // Partial-only
-      case (Some(Partial), None) =>
-        val updateExpressions = allAggregateFunctions.flatMap {
-          case ae: DeclarativeAggregate => ae.updateExpressions
-          case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
-        }
-        val imperativeAggregateFunctions: Array[ImperativeAggregate] =
-          allAggregateFunctions.collect { case func: ImperativeAggregate => func}
-        val expressionAggUpdateProjection =
-          newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
-
-        (currentBuffer: UnsafeRow, row: InternalRow) => {
-          expressionAggUpdateProjection.target(currentBuffer)
-          // Process all expression-based aggregate functions.
-          expressionAggUpdateProjection(joinedRow(currentBuffer, row))
-          // Process all imperative aggregate functions
-          var i = 0
-          while (i < imperativeAggregateFunctions.length) {
-            imperativeAggregateFunctions(i).update(currentBuffer, row)
-            i += 1
-          }
-        }
-
-      // PartialMerge-only or Final-only
-      case (Some(PartialMerge), None) | (Some(Final), None) =>
-        val mergeExpressions = allAggregateFunctions.flatMap {
-          case ae: DeclarativeAggregate => ae.mergeExpressions
-          case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
-        }
-        val imperativeAggregateFunctions: Array[ImperativeAggregate] =
-          allAggregateFunctions.collect { case func: ImperativeAggregate => func}
-        // This projection is used to merge buffer values for all expression-based aggregates.
-        val expressionAggMergeProjection =
-          newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)()
-
-        (currentBuffer: UnsafeRow, row: InternalRow) => {
-          // Process all expression-based aggregate functions.
-          expressionAggMergeProjection.target(currentBuffer)(joinedRow(currentBuffer, row))
-          // Process all imperative aggregate functions.
-          var i = 0
-          while (i < imperativeAggregateFunctions.length) {
-            imperativeAggregateFunctions(i).merge(currentBuffer, row)
-            i += 1
-          }
-        }
-
-      // Final-Complete
-      case (Some(Final), Some(Complete)) =>
-        val completeAggregateFunctions: Array[AggregateFunction] =
-          allAggregateFunctions.takeRight(completeAggregateExpressions.length)
-        val completeImperativeAggregateFunctions: Array[ImperativeAggregate] =
-          completeAggregateFunctions.collect { case func: ImperativeAggregate => func }
-        val nonCompleteAggregateFunctions: Array[AggregateFunction] =
-          allAggregateFunctions.take(nonCompleteAggregateExpressions.length)
-        val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] =
-          nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func }
-
-        val completeOffsetExpressions =
-          Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp)
-        val mergeExpressions =
-          nonCompleteAggregateFunctions.flatMap {
-            case ae: DeclarativeAggregate => ae.mergeExpressions
-            case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
-          } ++ completeOffsetExpressions
-        val finalMergeProjection =
-          newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)()
-
-        // We do not touch buffer values of aggregate functions with the Final mode.
-        val finalOffsetExpressions =
-          Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp)
-        val updateExpressions = finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
-          case ae: DeclarativeAggregate => ae.updateExpressions
-          case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
-        }
-        val completeUpdateProjection =
-          newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
-
-        (currentBuffer: UnsafeRow, row: InternalRow) => {
-          val input = joinedRow(currentBuffer, row)
-          // For all aggregate functions with mode Complete, update buffers.
-          completeUpdateProjection.target(currentBuffer)(input)
-          var i = 0
-          while (i < completeImperativeAggregateFunctions.length) {
-            completeImperativeAggregateFunctions(i).update(currentBuffer, row)
-            i += 1
-          }
-
-          // For all aggregate functions with mode Final, merge buffer values in row to
-          // currentBuffer.
-          finalMergeProjection.target(currentBuffer)(input)
-          i = 0
-          while (i < nonCompleteImperativeAggregateFunctions.length) {
-            nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row)
-            i += 1
-          }
-        }
-
-      // Complete-only
-      case (None, Some(Complete)) =>
-        val completeAggregateFunctions: Array[AggregateFunction] =
-          allAggregateFunctions.takeRight(completeAggregateExpressions.length)
-        // All imperative aggregate functions with mode Complete.
-        val completeImperativeAggregateFunctions: Array[ImperativeAggregate] =
-          completeAggregateFunctions.collect { case func: ImperativeAggregate => func }
-
-        val updateExpressions = completeAggregateFunctions.flatMap {
-          case ae: DeclarativeAggregate => ae.updateExpressions
-          case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
-        }
-        val completeExpressionAggUpdateProjection =
-          newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
-
-        (currentBuffer: UnsafeRow, row: InternalRow) => {
-          // For all aggregate functions with mode Complete, update buffers.
-          completeExpressionAggUpdateProjection.target(currentBuffer)(joinedRow(currentBuffer, row))
-          var i = 0
-          while (i < completeImperativeAggregateFunctions.length) {
-            completeImperativeAggregateFunctions(i).update(currentBuffer, row)
-            i += 1
-          }
-        }
-
-      // Grouping only.
-      case (None, None) => (currentBuffer: UnsafeRow, row: InternalRow) => {}
-
-      case other =>
-        throw new IllegalStateException(
-          s"${aggregationMode} should not be passed into TungstenAggregationIterator.")
-    }
-  }
-
   // Creates a function used to generate output rows.
-  private def generateResultProjection(): (UnsafeRow, UnsafeRow) => UnsafeRow = {
-
-    val groupingAttributes = groupingExpressions.map(_.toAttribute)
-    val bufferAttributes = allAggregateFunctions.flatMap(_.aggBufferAttributes)
-
-    aggregationMode match {
-      // Partial-only or PartialMerge-only: every output row is basically the values of
-      // the grouping expressions and the corresponding aggregation buffer.
-      case (Some(Partial), None) | (Some(PartialMerge), None) =>
-        val groupingKeySchema = StructType.fromAttributes(groupingAttributes)
-        val bufferSchema = StructType.fromAttributes(bufferAttributes)
-        val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
-
-        (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
-          unsafeRowJoiner.join(currentGroupingKey, currentBuffer)
-        }
-
-      // Final-only, Complete-only and Final-Complete: a output row is generated based on
-      // resultExpressions.
-      case (Some(Final), None) | (Some(Final) | None, Some(Complete)) =>
-        val joinedRow = new JoinedRow()
-        val evalExpressions = allAggregateFunctions.map {
-          case ae: DeclarativeAggregate => ae.evaluateExpression
-          case agg: AggregateFunction => NoOp
-        }
-        val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)()
-        // These are the attributes of the row produced by `expressionAggEvalProjection`
-        val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes
-        val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType))
-        expressionAggEvalProjection.target(aggregateResult)
-        val resultProjection =
-          UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateResultSchema)
-
-        val allImperativeAggregateFunctions: Array[ImperativeAggregate] =
-          allAggregateFunctions.collect { case func: ImperativeAggregate => func}
-
-        (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
-          // Generate results for all expression-based aggregate functions.
-          expressionAggEvalProjection(currentBuffer)
-          // Generate results for all imperative aggregate functions.
-          var i = 0
-          while (i < allImperativeAggregateFunctions.length) {
-            aggregateResult.update(
-              allImperativeAggregateFunctionPositions(i),
-              allImperativeAggregateFunctions(i).eval(currentBuffer))
-            i += 1
-          }
-          resultProjection(joinedRow(currentGroupingKey, aggregateResult))
-        }
-
-      // Grouping-only: a output row is generated from values of grouping expressions.
-      case (None, None) =>
-        val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes)
-
-        (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
-          resultProjection(currentGroupingKey)
-        }
-
-      case other =>
-        throw new IllegalStateException(
-          s"${aggregationMode} should not be passed into TungstenAggregationIterator.")
+  override protected def generateResultProjection(): (UnsafeRow, MutableRow) => UnsafeRow = {
+    val modes = aggregateExpressions.map(_.mode).distinct
+    if (modes.nonEmpty && !modes.contains(Final) && !modes.contains(Complete)) {
+      // Fast path for partial aggregation, UnsafeRowJoiner is usually faster than projection
+      val groupingAttributes = groupingExpressions.map(_.toAttribute)
+      val bufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes)
+      val groupingKeySchema = StructType.fromAttributes(groupingAttributes)
+      val bufferSchema = StructType.fromAttributes(bufferAttributes)
+      val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
+
+      (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => {
+        unsafeRowJoiner.join(currentGroupingKey, currentBuffer.asInstanceOf[UnsafeRow])
+      }
+    } else {
+      super.generateResultProjection()
     }
   }
 
-  // An UnsafeProjection used to extract grouping keys from the input rows.
-  private[this] val groupProjection =
-    UnsafeProjection.create(groupingExpressions, originalInputAttributes)
-
-  // A function used to process a input row. Its first argument is the aggregation buffer
-  // and the second argument is the input row.
-  private[this] var processRow: (UnsafeRow, InternalRow) => Unit =
-    generateProcessRow(originalInputAttributes)
-
-  // A function used to generate output rows based on the grouping keys (first argument)
-  // and the corresponding aggregation buffer (second argument).
-  private[this] var generateOutput: (UnsafeRow, UnsafeRow) => UnsafeRow =
-    generateResultProjection()
-
   // An aggregation buffer containing initial buffer values. It is used to
   // initialize other aggregation buffers.
   private[this] val initialAggregationBuffer: UnsafeRow = createNewAggregationBuffer()
@@ -482,7 +158,7 @@ class TungstenAggregationIterator(
   // all groups and their corresponding aggregation buffers for hash-based aggregation.
   private[this] val hashMap = new UnsafeFixedWidthAggregationMap(
     initialAggregationBuffer,
-    StructType.fromAttributes(allAggregateFunctions.flatMap(_.aggBufferAttributes)),
+    StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)),
     StructType.fromAttributes(groupingExpressions.map(_.toAttribute)),
     TaskContext.get().taskMemoryManager(),
     1024 * 16, // initial capacity
@@ -499,7 +175,7 @@ class TungstenAggregationIterator(
     if (groupingExpressions.isEmpty) {
       // If there is no grouping expressions, we can just reuse the same buffer over and over again.
       // Note that it would be better to eliminate the hash map entirely in the future.
-      val groupingKey = groupProjection.apply(null)
+      val groupingKey = groupingProjection.apply(null)
       val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
       while (inputIter.hasNext) {
         val newInput = inputIter.next()
@@ -511,7 +187,7 @@ class TungstenAggregationIterator(
       while (inputIter.hasNext) {
         val newInput = inputIter.next()
         numInputRows += 1
-        val groupingKey = groupProjection.apply(newInput)
+        val groupingKey = groupingProjection.apply(newInput)
         var buffer: UnsafeRow = null
         if (i < fallbackStartsAt) {
           buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
@@ -565,25 +241,18 @@ class TungstenAggregationIterator(
   private def switchToSortBasedAggregation(): Unit = {
     logInfo("falling back to sort based aggregation.")
 
-    // Set aggregationMode, processRow, and generateOutput for sort-based aggregation.
-    val newAggregationMode = aggregationMode match {
-      case (Some(Partial), None) => (Some(PartialMerge), None)
-      case (None, Some(Complete)) => (Some(Final), None)
-      case (Some(Final), Some(Complete)) => (Some(Final), None)
+    // Basically the value of the KVIterator returned by externalSorter
+    // will be just aggregation buffer, so we rewrite the aggregateExpressions to reflect it.
+    val newExpressions = aggregateExpressions.map {
+      case agg @ AggregateExpression(_, Partial, _) =>
+        agg.copy(mode = PartialMerge)
+      case agg @ AggregateExpression(_, Complete, _) =>
+        agg.copy(mode = Final)
       case other => other
     }
-    aggregationMode = newAggregationMode
-
-    allAggregateFunctions = initializeAllAggregateFunctions(startingInputBufferOffset = 0)
-
-    // Basically the value of the KVIterator returned by externalSorter
-    // will just aggregation buffer. At here, we use inputAggBufferAttributes.
-    val newInputAttributes: Seq[Attribute] =
-      allAggregateFunctions.flatMap(_.inputAggBufferAttributes)
-
-    // Set up new processRow and generateOutput.
-    processRow = generateProcessRow(newInputAttributes)
-    generateOutput = generateResultProjection()
+    val newFunctions = initializeAggregateFunctions(newExpressions, 0)
+    val newInputAttributes = newFunctions.flatMap(_.inputAggBufferAttributes)
+    sortBasedProcessRow = generateProcessRow(newExpressions, newFunctions, newInputAttributes)
 
     // Step 5: Get the sorted iterator from the externalSorter.
     sortedKVIterator = externalSorter.sortedIterator()
@@ -632,6 +301,9 @@ class TungstenAggregationIterator(
   // The aggregation buffer used by the sort-based aggregation.
   private[this] val sortBasedAggregationBuffer: UnsafeRow = createNewAggregationBuffer()
 
+  // The function used to process rows in a group
+  private[this] var sortBasedProcessRow: (MutableRow, InternalRow) => Unit = null
+
   // Processes rows in the current group. It will stop when it find a new group.
   private def processCurrentSortedGroup(): Unit = {
     // First, we need to copy nextGroupingKey to currentGroupingKey.
@@ -640,7 +312,7 @@ class TungstenAggregationIterator(
     // We create a variable to track if we see the next group.
     var findNextPartition = false
     // firstRowInNextGroup is the first row of this group. We first process it.
-    processRow(sortBasedAggregationBuffer, firstRowInNextGroup)
+    sortBasedProcessRow(sortBasedAggregationBuffer, firstRowInNextGroup)
 
     // The search will stop when we see the next group or there is no
     // input row left in the iter.
@@ -655,16 +327,15 @@ class TungstenAggregationIterator(
 
       // Check if the current row belongs the current input row.
       if (currentGroupingKey.equals(groupingKey)) {
-        processRow(sortBasedAggregationBuffer, inputAggregationBuffer)
+        sortBasedProcessRow(sortBasedAggregationBuffer, inputAggregationBuffer)
 
         hasNext = sortedKVIterator.next()
       } else {
         // We find a new group.
         findNextPartition = true
         // copyFrom will fail when
-        nextGroupingKey.copyFrom(groupingKey) // = groupingKey.copy()
-        firstRowInNextGroup.copyFrom(inputAggregationBuffer) // = inputAggregationBuffer.copy()
-
+        nextGroupingKey.copyFrom(groupingKey)
+        firstRowInNextGroup.copyFrom(inputAggregationBuffer)
       }
     }
     // We have not seen a new group. It means that there is no new row in the input


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org