Posted to commits@spark.apache.org by yh...@apache.org on 2015/10/15 02:28:04 UTC

spark git commit: [SPARK-11017] [SQL] Support ImperativeAggregates in TungstenAggregate

Repository: spark
Updated Branches:
  refs/heads/master 1baaf2b9b -> 4ace4f8a9


[SPARK-11017] [SQL] Support ImperativeAggregates in TungstenAggregate

This patch extends TungstenAggregate to support ImperativeAggregate functions. The existing TungstenAggregate operator only supported DeclarativeAggregate functions, which are defined in terms of Catalyst expressions and can be evaluated via generated projections. ImperativeAggregate functions, on the other hand, are evaluated by calling their `initialize`, `update`, `merge`, and `eval` methods.

The basic strategy here is similar to how SortBasedAggregate evaluates both types of aggregate functions: use a generated projection to evaluate the expression-based declarative aggregates with dummy placeholder expressions inserted in place of the imperative aggregate function output, then invoke the imperative aggregate functions and target them against the aggregation buffer. The bulk of the diff here consists of code that was copied and adapted from SortBasedAggregate, with some key changes to handle TungstenAggregate's sort fallback path.
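
As a rough, self-contained sketch of that two-pass strategy (the types below are simplified stand-ins, not Spark's actual classes):

    object MixedAggregationSketch {
      type Buffer = Array[Long]

      sealed trait Agg { def update(buffer: Buffer, input: Long): Unit }
      // Stand-in for aggregates evaluated through a generated projection, e.g. SUM.
      final case class Declarative(slot: Int) extends Agg {
        def update(buffer: Buffer, input: Long): Unit = buffer(slot) += input
      }
      // Stand-in for aggregates evaluated by calling update() directly, e.g. COUNT.
      final case class Imperative(slot: Int) extends Agg {
        def update(buffer: Buffer, input: Long): Unit = buffer(slot) += 1L
      }

      def processRow(aggs: Seq[Agg], buffer: Buffer, input: Long): Unit = {
        // Pass 1: the "generated projection" updates all declarative aggregates;
        // imperative slots are covered by no-op placeholders and left untouched.
        aggs.collect { case d: Declarative => d }.foreach(_.update(buffer, input))
        // Pass 2: invoke each imperative aggregate against the same shared buffer.
        aggs.collect { case i: Imperative => i }.foreach(_.update(buffer, input))
      }

      def main(args: Array[String]): Unit = {
        val aggs = Seq(Declarative(slot = 0), Imperative(slot = 1)) // SUM(x), COUNT(x)
        val buffer = new Array[Long](2)
        Seq(3L, 4L, 5L).foreach(processRow(aggs, buffer, _))
        println(buffer.toSeq) // sum = 12, count = 3
      }
    }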

Author: Josh Rosen <jo...@databricks.com>

Closes #9038 from JoshRosen/support-interpreted-in-tungsten-agg-final.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4ace4f8a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4ace4f8a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4ace4f8a

Branch: refs/heads/master
Commit: 4ace4f8a9c91beb21a0077e12b75637a4560a542
Parents: 1baaf2b
Author: Josh Rosen <jo...@databricks.com>
Authored: Wed Oct 14 17:27:27 2015 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Wed Oct 14 17:27:50 2015 -0700

----------------------------------------------------------------------
 .../expressions/aggregate/functions.scala       |  19 +-
 .../expressions/aggregate/interfaces.scala      |  31 ++-
 .../aggregate/AggregationIterator.scala         |  29 +-
 .../execution/aggregate/TungstenAggregate.scala |  22 +-
 .../aggregate/TungstenAggregationIterator.scala | 250 +++++++++++++----
 .../spark/sql/execution/aggregate/udaf.scala    |  79 +++---
 .../spark/sql/execution/aggregate/utils.scala   | 269 +++++++++----------
 .../TungstenAggregationIteratorSuite.scala      |   2 +-
 .../org/apache/spark/sql/hive/hiveUDFs.scala    |  16 +-
 9 files changed, 457 insertions(+), 260 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4ace4f8a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index 8aad0b7..c0bc7ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -472,10 +472,20 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
  * @param relativeSD the maximum estimation error allowed.
  */
 // scalastyle:on
-case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05)
-    extends ImperativeAggregate {
+case class HyperLogLogPlusPlus(
+    child: Expression,
+    relativeSD: Double = 0.05,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends ImperativeAggregate {
   import HyperLogLogPlusPlus._
 
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
   /**
    * HLL++ uses 'p' bits for addressing. The more addressing bits we use, the more precise the
    * algorithm will be, and the more memory it will require. The 'p' value is based on the relative
@@ -546,6 +556,11 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05)
     AttributeReference(s"MS[$i]", LongType)()
   }
 
+  // Note: although this simply copies aggBufferAttributes, this common code cannot be placed
+  // in the superclass because that will lead to initialization ordering issues.
+  override val inputAggBufferAttributes: Seq[AttributeReference] =
+    aggBufferAttributes.map(_.newInstance())
+
   /** Fill all words with zeros. */
   override def initialize(buffer: MutableRow): Unit = {
     var word = 0
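
In miniature, the shape of this change (MiniAgg is a hypothetical stand-in, not a Spark class): the two offsets become constructor parameters with defaults, and the withNew* methods return fresh copies via the case-class copy method instead of mutating state.

    case class MiniAgg(
        relativeSD: Double = 0.05,
        mutableAggBufferOffset: Int = 0,
        inputAggBufferOffset: Int = 0) {
      def withNewMutableAggBufferOffset(n: Int): MiniAgg = copy(mutableAggBufferOffset = n)
      def withNewInputAggBufferOffset(n: Int): MiniAgg = copy(inputAggBufferOffset = n)
    }

    // Usage: the original instance is never mutated.
    //   val a = MiniAgg()
    //   val b = a.withNewMutableAggBufferOffset(2) // a.mutableAggBufferOffset is still 0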

http://git-wip-us.apache.org/repos/asf/spark/blob/4ace4f8a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 9ba3a9c..a2fab25 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -150,6 +150,10 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp
  * We need to perform similar field number arithmetic when merging multiple intermediate
  * aggregate buffers together in `merge()` (in this case, use `inputAggBufferOffset` when accessing
  * the input buffer).
+ *
+ * Correct ImperativeAggregate evaluation depends on the correctness of `mutableAggBufferOffset` and
+ * `inputAggBufferOffset`, but not on the correctness of the attribute ids in `aggBufferAttributes`
+ * and `inputAggBufferAttributes`.
  */
 abstract class ImperativeAggregate extends AggregateFunction2 {
 
@@ -172,11 +176,13 @@ abstract class ImperativeAggregate extends AggregateFunction2 {
    *                     avg(y) mutableAggBufferOffset = 2
    *
    */
-  protected var mutableAggBufferOffset: Int = 0
+  protected val mutableAggBufferOffset: Int
 
-  def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Unit = {
-    mutableAggBufferOffset = newMutableAggBufferOffset
-  }
+  /**
+   * Returns a copy of this ImperativeAggregate with an updated mutableAggBufferOffset.
+   * This new copy's attributes may have different ids than the original.
+   */
+  def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate
 
   /**
    * The offset of this function's start buffer value in the underlying shared input aggregation
@@ -203,11 +209,17 @@ abstract class ImperativeAggregate extends AggregateFunction2 {
    *                       avg(y) inputAggBufferOffset = 3
    *
    */
-  protected var inputAggBufferOffset: Int = 0
+  protected val inputAggBufferOffset: Int
 
-  def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): Unit = {
-    inputAggBufferOffset = newInputAggBufferOffset
-  }
+  /**
+   * Returns a copy of this ImperativeAggregate with an updated inputAggBufferOffset.
+   * This new copy's attributes may have different ids than the original.
+   */
+  def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate
+
+  // Note: although all subclasses implement inputAggBufferAttributes by simply cloning
+  // aggBufferAttributes, that common clone code cannot be placed here in the abstract
+  // ImperativeAggregate class, since that will lead to initialization ordering issues.
 
   /**
    * Initializes the mutable aggregation buffer located in `mutableAggBuffer`.
@@ -231,9 +243,6 @@ abstract class ImperativeAggregate extends AggregateFunction2 {
    * Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`.
    */
   def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit
-
-  final lazy val inputAggBufferAttributes: Seq[AttributeReference] =
-    aggBufferAttributes.map(_.newInstance())
 }
 
 /**
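
The offset discipline documented in this file can be made concrete with a small worked miniature, using the avg(x)/avg(y) layout from the comments above (names and layout are illustrative):

    object OffsetArithmeticSketch {
      // Shared mutable buffer: [xSum, xCount, ySum, yCount]
      //   avg(x): mutableAggBufferOffset = 0, avg(y): mutableAggBufferOffset = 2
      // Input rows seen by merge() carry the grouping key first:
      //   [key, xSum, xCount, ySum, yCount]
      //   avg(x): inputAggBufferOffset = 1, avg(y): inputAggBufferOffset = 3
      def mergeAvg(
          mutableBuf: Array[Long], mutableAggBufferOffset: Int,
          inputBuf: Array[Long], inputAggBufferOffset: Int): Unit = {
        mutableBuf(mutableAggBufferOffset) += inputBuf(inputAggBufferOffset)         // sum
        mutableBuf(mutableAggBufferOffset + 1) += inputBuf(inputAggBufferOffset + 1) // count
      }
    }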

http://git-wip-us.apache.org/repos/asf/spark/blob/4ace4f8a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 8e0fbd1..99fb7a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -83,7 +83,7 @@ abstract class AggregationIterator(
     var i = 0
     while (i < allAggregateExpressions.length) {
       val func = allAggregateExpressions(i).aggregateFunction
-      val funcWithBoundReferences = allAggregateExpressions(i).mode match {
+      val funcWithBoundReferences: AggregateFunction2 = allAggregateExpressions(i).mode match {
         case Partial | Complete if func.isInstanceOf[ImperativeAggregate] =>
           // We need to create BoundReferences if the function is not an
           // expression-based aggregate function (it does not support code-gen) and the mode of
@@ -94,24 +94,24 @@ abstract class AggregationIterator(
         case _ =>
           // We only need to set inputBufferOffset for aggregate functions with mode
           // PartialMerge and Final.
-          func match {
+          val updatedFunc = func match {
             case function: ImperativeAggregate =>
               function.withNewInputAggBufferOffset(inputBufferOffset)
-            case _ =>
+            case function => function
           }
           inputBufferOffset += func.aggBufferSchema.length
-          func
+          updatedFunc
       }
-      // Set mutableBufferOffset for this function. It is important that setting
-      // mutableBufferOffset happens after all potential bindReference operations
-      // because bindReference will create a new instance of the function.
-      funcWithBoundReferences match {
+      val funcWithUpdatedAggBufferOffset = funcWithBoundReferences match {
         case function: ImperativeAggregate =>
+          // Set mutableBufferOffset for this function. It is important that setting
+          // mutableBufferOffset happens after all potential bindReference operations
+          // because bindReference will create a new instance of the function.
           function.withNewMutableAggBufferOffset(mutableBufferOffset)
-        case _ =>
+        case function => function
       }
-      mutableBufferOffset += funcWithBoundReferences.aggBufferSchema.length
-      functions(i) = funcWithBoundReferences
+      mutableBufferOffset += funcWithUpdatedAggBufferOffset.aggBufferSchema.length
+      functions(i) = funcWithUpdatedAggBufferOffset
       i += 1
     }
     functions
@@ -320,7 +320,7 @@ abstract class AggregationIterator(
   // Initializing the function used to generate the output row.
   protected val generateOutput: (InternalRow, MutableRow) => InternalRow = {
     val rowToBeEvaluated = new JoinedRow
-    val safeOutputRow = new GenericMutableRow(resultExpressions.length)
+    val safeOutputRow = new SpecificMutableRow(resultExpressions.map(_.dataType))
     val mutableOutput = if (outputsUnsafeRows) {
       UnsafeProjection.create(resultExpressions.map(_.dataType).toArray).apply(safeOutputRow)
     } else {
@@ -358,7 +358,8 @@ abstract class AggregationIterator(
         val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)()
         val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes
         // TODO: Use unsafe row.
-        val aggregateResult = new GenericMutableRow(aggregateResultSchema.length)
+        val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType))
+        expressionAggEvalProjection.target(aggregateResult)
         val resultProjection =
           newMutableProjection(
             resultExpressions, groupingKeyAttributes ++ aggregateResultSchema)()
@@ -366,7 +367,7 @@ abstract class AggregationIterator(
 
         (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => {
           // Generate results for all expression-based aggregate functions.
-          expressionAggEvalProjection.target(aggregateResult)(currentBuffer)
+          expressionAggEvalProjection(currentBuffer)
           // Generate results for all imperative aggregate functions.
           var i = 0
           while (i < allImperativeAggregateFunctions.length) {
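
One subtlety in the rewritten initialization loop above: offsets must be applied to the instance returned from reference binding, because binding creates a new function instance. A hypothetical miniature of the hazard:

    object BindingOrderSketch {
      case class MiniFunc(bound: Boolean = false, mutableAggBufferOffset: Int = 0) {
        // Stand-in for BindReferences.bindReference: returns a fresh instance
        // whose offset is back at the default.
        def bindRefs: MiniFunc = MiniFunc(bound = true)
        def withNewMutableAggBufferOffset(n: Int): MiniFunc = copy(mutableAggBufferOffset = n)
      }

      val wrong = MiniFunc().withNewMutableAggBufferOffset(2).bindRefs
      // wrong.mutableAggBufferOffset == 0: the offset was lost in the rebind
      val right = MiniFunc().bindRefs.withNewMutableAggBufferOffset(2)
      // right.mutableAggBufferOffset == 2
    }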

http://git-wip-us.apache.org/repos/asf/spark/blob/4ace4f8a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 7b3d072..c342940 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{UnaryNode, SparkPlan}
+import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnaryNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.types.StructType
 
 case class TungstenAggregate(
     requiredChildDistributionExpressions: Option[Seq[Expression]],
@@ -34,10 +35,18 @@ case class TungstenAggregate(
     nonCompleteAggregateAttributes: Seq[Attribute],
     completeAggregateExpressions: Seq[AggregateExpression2],
     completeAggregateAttributes: Seq[Attribute],
+    initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
   extends UnaryNode {
 
+  private[this] val aggregateBufferAttributes = {
+    (nonCompleteAggregateExpressions ++ completeAggregateExpressions)
+      .flatMap(_.aggregateFunction.aggBufferAttributes)
+  }
+
+  require(TungstenAggregate.supportsAggregate(groupingExpressions, aggregateBufferAttributes))
+
   override private[sql] lazy val metrics = Map(
     "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
@@ -82,6 +91,7 @@ case class TungstenAggregate(
         nonCompleteAggregateAttributes,
         completeAggregateExpressions,
         completeAggregateAttributes,
+        initialInputBufferOffset,
         resultExpressions,
         newMutableProjection,
         child.output,
@@ -138,3 +148,13 @@ case class TungstenAggregate(
     }
   }
 }
+
+object TungstenAggregate {
+  def supportsAggregate(
+    groupingExpressions: Seq[Expression],
+    aggregateBufferAttributes: Seq[Attribute]): Boolean = {
+    val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
+    UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
+      UnsafeProjection.canSupport(groupingExpressions)
+  }
+}
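
For reference, the planner consults this new guard when choosing between hash-based and sort-based aggregation; the call site (from the utils.scala hunk later in this diff) looks like:

    val usesTungstenAggregate =
      child.sqlContext.conf.unsafeEnabled &&
        TungstenAggregate.supportsAggregate(
          groupingExpressions,
          aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))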

http://git-wip-us.apache.org/repos/asf/spark/blob/4ace4f8a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 4bb95c9..fe708a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.aggregate
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.unsafe.KVIterator
 import org.apache.spark.{InternalAccumulator, Logging, SparkEnv, TaskContext}
 import org.apache.spark.sql.catalyst.expressions._
@@ -79,6 +81,7 @@ class TungstenAggregationIterator(
     nonCompleteAggregateAttributes: Seq[Attribute],
     completeAggregateExpressions: Seq[AggregateExpression2],
     completeAggregateAttributes: Seq[Attribute],
+    initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
     originalInputAttributes: Seq[Attribute],
@@ -134,19 +137,74 @@ class TungstenAggregationIterator(
       completeAggregateExpressions.map(_.mode).distinct.headOption
   }
 
-  // All aggregate functions. TungstenAggregationIterator only handles expression-based aggregate.
-  // If there is any functions that is an ImperativeAggregateFunction, we throw an
-  // IllegalStateException.
-  private[this] val allAggregateFunctions: Array[DeclarativeAggregate] = {
-    if (!allAggregateExpressions.forall(
-        _.aggregateFunction.isInstanceOf[DeclarativeAggregate])) {
-      throw new IllegalStateException(
-        "Only ExpressionAggregateFunctions should be passed in TungstenAggregationIterator.")
+  // Initialize all AggregateFunctions by binding references, if necessary,
+  // and setting inputBufferOffset and mutableBufferOffset.
+  private def initializeAllAggregateFunctions(
+      startingInputBufferOffset: Int): Array[AggregateFunction2] = {
+    var mutableBufferOffset = 0
+    var inputBufferOffset: Int = startingInputBufferOffset
+    val functions = new Array[AggregateFunction2](allAggregateExpressions.length)
+    var i = 0
+    while (i < allAggregateExpressions.length) {
+      val func = allAggregateExpressions(i).aggregateFunction
+      val aggregateExpressionIsNonComplete = i < nonCompleteAggregateExpressions.length
+      // We need to use this mode instead of func.mode in order to handle aggregation mode switching
+      // when switching to sort-based aggregation:
+      val mode = if (aggregateExpressionIsNonComplete) aggregationMode._1 else aggregationMode._2
+      val funcWithBoundReferences = mode match {
+        case Some(Partial) | Some(Complete) if func.isInstanceOf[ImperativeAggregate] =>
+          // We need to create BoundReferences if the function is not an
+          // expression-based aggregate function (it does not support code-gen) and the mode of
+          // this function is Partial or Complete because we will call eval of this
+          // function's children in the update method of this aggregate function.
+          // Those eval calls require BoundReferences to work.
+          BindReferences.bindReference(func, originalInputAttributes)
+        case _ =>
+          // We only need to set inputBufferOffset for aggregate functions with mode
+          // PartialMerge and Final.
+          val updatedFunc = func match {
+            case function: ImperativeAggregate =>
+              function.withNewInputAggBufferOffset(inputBufferOffset)
+            case function => function
+          }
+          inputBufferOffset += func.aggBufferSchema.length
+          updatedFunc
+      }
+      val funcWithUpdatedAggBufferOffset = funcWithBoundReferences match {
+        case function: ImperativeAggregate =>
+          // Set mutableBufferOffset for this function. It is important that setting
+          // mutableBufferOffset happens after all potential bindReference operations
+          // because bindReference will create a new instance of the function.
+          function.withNewMutableAggBufferOffset(mutableBufferOffset)
+        case function => function
+      }
+      mutableBufferOffset += funcWithUpdatedAggBufferOffset.aggBufferSchema.length
+      functions(i) = funcWithUpdatedAggBufferOffset
+      i += 1
     }
+    functions
+  }
 
-    allAggregateExpressions
-      .map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
-      .toArray
+  private[this] var allAggregateFunctions: Array[AggregateFunction2] =
+    initializeAllAggregateFunctions(initialInputBufferOffset)
+
+  // Positions of those imperative aggregate functions in allAggregateFunctions.
+  // For example, say that we have func1, func2, func3, func4 in aggregateFunctions, and
+  // func2 and func3 are imperative aggregate functions. Then
+  // allImperativeAggregateFunctionPositions will be [1, 2]. Note that this does not need to be
+  // updated when falling back to sort-based aggregation because the positions of the aggregate
+  // functions do not change in that case.
+  private[this] val allImperativeAggregateFunctionPositions: Array[Int] = {
+    val positions = new ArrayBuffer[Int]()
+    var i = 0
+    while (i < allAggregateFunctions.length) {
+      allAggregateFunctions(i) match {
+        case agg: DeclarativeAggregate =>
+        case _ => positions += i
+      }
+      i += 1
+    }
+    positions.toArray
   }
 
   ///////////////////////////////////////////////////////////////////////////
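
The position bookkeeping above fits in a one-line miniature ("d" = declarative, "i" = imperative, mirroring the func1..func4 example in the comment):

    object PositionsSketch {
      val positions: Seq[Int] =
        Seq("d", "i", "i", "d").zipWithIndex.collect { case ("i", idx) => idx }
      // positions == Seq(1, 2)
    }
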
@@ -155,25 +213,31 @@ class TungstenAggregationIterator(
   //         rows.
   ///////////////////////////////////////////////////////////////////////////
 
-  // The projection used to initialize buffer values.
-  private[this] val initialProjection: MutableProjection = {
-    val initExpressions = allAggregateFunctions.flatMap(_.initialValues)
+  // The projection used to initialize buffer values for all expression-based aggregates.
+  // Note that this projection does not need to be updated when switching to sort-based aggregation
+  // because the schema of empty aggregation buffers does not change in that case.
+  private[this] val expressionAggInitialProjection: MutableProjection = {
+    val initExpressions = allAggregateFunctions.flatMap {
+      case ae: DeclarativeAggregate => ae.initialValues
+      // For the positions corresponding to imperative aggregate functions, we'll use special
+      // no-op expressions which are ignored during projection code-generation.
+      case i: ImperativeAggregate => Seq.fill(i.aggBufferAttributes.length)(NoOp)
+    }
     newMutableProjection(initExpressions, Nil)()
   }
 
   // Creates a new aggregation buffer and initializes buffer values.
-  // This functions should be only called at most three times (when we create the hash map,
+  // This function should be called at most three times (when we create the hash map,
   // when we switch to sort-based aggregation, and when we create the re-used buffer for
   // sort-based aggregation).
   private def createNewAggregationBuffer(): UnsafeRow = {
     val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes)
-    val bufferRowSize: Int = bufferSchema.length
-
-    val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
-    val unsafeProjection =
-      UnsafeProjection.create(bufferSchema.map(_.dataType))
-    val buffer = unsafeProjection.apply(genericMutableBuffer)
-    initialProjection.target(buffer)(EmptyRow)
+    val buffer: UnsafeRow = UnsafeProjection.create(bufferSchema.map(_.dataType))
+      .apply(new GenericMutableRow(bufferSchema.length))
+    // Initialize declarative aggregates' buffer values
+    expressionAggInitialProjection.target(buffer)(EmptyRow)
+    // Initialize imperative aggregates' buffer values
+    allAggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer))
     buffer
   }
 
@@ -187,72 +251,124 @@ class TungstenAggregationIterator(
     aggregationMode match {
       // Partial-only
       case (Some(Partial), None) =>
-        val updateExpressions = allAggregateFunctions.flatMap(_.updateExpressions)
-        val updateProjection =
+        val updateExpressions = allAggregateFunctions.flatMap {
+          case ae: DeclarativeAggregate => ae.updateExpressions
+          case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+        }
+        val imperativeAggregateFunctions: Array[ImperativeAggregate] =
+          allAggregateFunctions.collect { case func: ImperativeAggregate => func }
+        val expressionAggUpdateProjection =
           newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
 
         (currentBuffer: UnsafeRow, row: InternalRow) => {
-          updateProjection.target(currentBuffer)
-          updateProjection(joinedRow(currentBuffer, row))
+          expressionAggUpdateProjection.target(currentBuffer)
+          // Process all expression-based aggregate functions.
+          expressionAggUpdateProjection(joinedRow(currentBuffer, row))
+          // Process all imperative aggregate functions
+          var i = 0
+          while (i < imperativeAggregateFunctions.length) {
+            imperativeAggregateFunctions(i).update(currentBuffer, row)
+            i += 1
+          }
         }
 
       // PartialMerge-only or Final-only
       case (Some(PartialMerge), None) | (Some(Final), None) =>
-        val mergeExpressions = allAggregateFunctions.flatMap(_.mergeExpressions)
-        val mergeProjection =
+        val mergeExpressions = allAggregateFunctions.flatMap {
+          case ae: DeclarativeAggregate => ae.mergeExpressions
+          case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+        }
+        val imperativeAggregateFunctions: Array[ImperativeAggregate] =
+          allAggregateFunctions.collect { case func: ImperativeAggregate => func }
+        // This projection is used to merge buffer values for all expression-based aggregates.
+        val expressionAggMergeProjection =
           newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)()
 
         (currentBuffer: UnsafeRow, row: InternalRow) => {
-          mergeProjection.target(currentBuffer)
-          mergeProjection(joinedRow(currentBuffer, row))
+          // Process all expression-based aggregate functions.
+          expressionAggMergeProjection.target(currentBuffer)(joinedRow(currentBuffer, row))
+          // Process all imperative aggregate functions.
+          var i = 0
+          while (i < imperativeAggregateFunctions.length) {
+            imperativeAggregateFunctions(i).merge(currentBuffer, row)
+            i += 1
+          }
         }
 
       // Final-Complete
       case (Some(Final), Some(Complete)) =>
-        val nonCompleteAggregateFunctions: Array[DeclarativeAggregate] =
-          allAggregateFunctions.take(nonCompleteAggregateExpressions.length)
-        val completeAggregateFunctions: Array[DeclarativeAggregate] =
+        val completeAggregateFunctions: Array[AggregateFunction2] =
           allAggregateFunctions.takeRight(completeAggregateExpressions.length)
+        val completeImperativeAggregateFunctions: Array[ImperativeAggregate] =
+          completeAggregateFunctions.collect { case func: ImperativeAggregate => func }
+        val nonCompleteAggregateFunctions: Array[AggregateFunction2] =
+          allAggregateFunctions.take(nonCompleteAggregateExpressions.length)
+        val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] =
+          nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func }
 
         val completeOffsetExpressions =
           Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp)
         val mergeExpressions =
-          nonCompleteAggregateFunctions.flatMap(_.mergeExpressions) ++ completeOffsetExpressions
+          nonCompleteAggregateFunctions.flatMap {
+            case ae: DeclarativeAggregate => ae.mergeExpressions
+            case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+          } ++ completeOffsetExpressions
         val finalMergeProjection =
           newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)()
 
         // We do not touch buffer values of aggregate functions with the Final mode.
         val finalOffsetExpressions =
           Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp)
-        val updateExpressions =
-          finalOffsetExpressions ++ completeAggregateFunctions.flatMap(_.updateExpressions)
+        val updateExpressions = finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
+          case ae: DeclarativeAggregate => ae.updateExpressions
+          case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+        }
         val completeUpdateProjection =
           newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
 
         (currentBuffer: UnsafeRow, row: InternalRow) => {
           val input = joinedRow(currentBuffer, row)
-          // For all aggregate functions with mode Complete, update the given currentBuffer.
+          // For all aggregate functions with mode Complete, update buffers.
           completeUpdateProjection.target(currentBuffer)(input)
+          var i = 0
+          while (i < completeImperativeAggregateFunctions.length) {
+            completeImperativeAggregateFunctions(i).update(currentBuffer, row)
+            i += 1
+          }
 
           // For all aggregate functions with mode Final, merge buffer values in row to
           // currentBuffer.
           finalMergeProjection.target(currentBuffer)(input)
+          i = 0
+          while (i < nonCompleteImperativeAggregateFunctions.length) {
+            nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row)
+            i += 1
+          }
         }
 
       // Complete-only
       case (None, Some(Complete)) =>
-        val completeAggregateFunctions: Array[DeclarativeAggregate] =
+        val completeAggregateFunctions: Array[AggregateFunction2] =
           allAggregateFunctions.takeRight(completeAggregateExpressions.length)
+        // All imperative aggregate functions with mode Complete.
+        val completeImperativeAggregateFunctions: Array[ImperativeAggregate] =
+          completeAggregateFunctions.collect { case func: ImperativeAggregate => func }
 
-        val updateExpressions =
-          completeAggregateFunctions.flatMap(_.updateExpressions)
-        val completeUpdateProjection =
+        val updateExpressions = completeAggregateFunctions.flatMap {
+          case ae: DeclarativeAggregate => ae.updateExpressions
+          case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+        }
+        val completeExpressionAggUpdateProjection =
           newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
 
         (currentBuffer: UnsafeRow, row: InternalRow) => {
-          completeUpdateProjection.target(currentBuffer)
-          // For all aggregate functions with mode Complete, update the given currentBuffer.
-          completeUpdateProjection(joinedRow(currentBuffer, row))
+          // For all aggregate functions with mode Complete, update buffers.
+          completeExpressionAggUpdateProjection.target(currentBuffer)(joinedRow(currentBuffer, row))
+          var i = 0
+          while (i < completeImperativeAggregateFunctions.length) {
+            completeImperativeAggregateFunctions(i).update(currentBuffer, row)
+            i += 1
+          }
         }
 
       // Grouping only.
@@ -288,17 +404,30 @@ class TungstenAggregationIterator(
         val joinedRow = new JoinedRow()
         val evalExpressions = allAggregateFunctions.map {
           case ae: DeclarativeAggregate => ae.evaluateExpression
-          // case agg: AggregateFunction2 => Literal.create(null, agg.dataType)
+          case agg: AggregateFunction2 => NoOp
         }
-        val expressionAggEvalProjection = UnsafeProjection.create(evalExpressions, bufferAttributes)
+        val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)()
         // These are the attributes of the row produced by `expressionAggEvalProjection`
         val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes
+        val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType))
+        expressionAggEvalProjection.target(aggregateResult)
         val resultProjection =
           UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateResultSchema)
 
+        val allImperativeAggregateFunctions: Array[ImperativeAggregate] =
+          allAggregateFunctions.collect { case func: ImperativeAggregate => func }
+
         (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
           // Generate results for all expression-based aggregate functions.
-          val aggregateResult = expressionAggEvalProjection.apply(currentBuffer)
+          expressionAggEvalProjection(currentBuffer)
+          // Generate results for all imperative aggregate functions.
+          var i = 0
+          while (i < allImperativeAggregateFunctions.length) {
+            aggregateResult.update(
+              allImperativeAggregateFunctionPositions(i),
+              allImperativeAggregateFunctions(i).eval(currentBuffer))
+            i += 1
+          }
           resultProjection(joinedRow(currentGroupingKey, aggregateResult))
         }
 
@@ -481,10 +610,27 @@ class TungstenAggregationIterator(
       // When needsProcess is false, the format of input rows is groupingKey + aggregation buffer.
       // We need to project the aggregation buffer part from an input row.
       val buffer = createNewAggregationBuffer()
-      // The originalInputAttributes are using cloneBufferAttributes. So, we need to use
-      // allAggregateFunctions.flatMap(_.cloneBufferAttributes).
+      // In principle, we could use `allAggregateFunctions.flatMap(_.inputAggBufferAttributes)` to
+      // extract the aggregation buffer. In practice, however, we extract it positionally by relying
+      // on it being present at the end of the row. The reason for this relates to how the different
+      // aggregates handle input binding.
+      //
+      // ImperativeAggregate uses field numbers and field number offsets to manipulate its buffers,
+      // so its correctness does not rely on attribute bindings. When we fall back to sort-based
+      // aggregation, these field number offsets (mutableAggBufferOffset and inputAggBufferOffset)
+      // need to be updated and any internal state in the aggregate functions themselves must be
+      // reset, so we call withNewMutableAggBufferOffset and withNewInputAggBufferOffset to reset
+      // this state and update the offsets.
+      //
+      // The updated ImperativeAggregate will have different attribute ids for its
+      // aggBufferAttributes and inputAggBufferAttributes. This isn't a problem for the actual
+      // ImperativeAggregate evaluation, but it means that
+      // `allAggregateFunctions.flatMap(_.inputAggBufferAttributes)` will no longer match the
+      // attributes in `originalInputAttributes`, which is why we can't use those attributes here.
+      //
+      // For more details, see the discussion on PR #9038.
       val bufferExtractor = newMutableProjection(
-        allAggregateFunctions.flatMap(_.inputAggBufferAttributes),
+        originalInputAttributes.drop(initialInputBufferOffset),
         originalInputAttributes)()
       bufferExtractor.target(buffer)
 
@@ -511,8 +657,10 @@ class TungstenAggregationIterator(
     }
     aggregationMode = newAggregationMode
 
+    allAggregateFunctions = initializeAllAggregateFunctions(startingInputBufferOffset = 0)
+
     // Basically the value of the KVIterator returned by externalSorter
-    // will just aggregation buffer. At here, we use cloneBufferAttributes.
+    // will just be the aggregation buffer. Here, we use inputAggBufferAttributes.
     val newInputAttributes: Seq[Attribute] =
       allAggregateFunctions.flatMap(_.inputAggBufferAttributes)
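
The positional extraction described in the long comment above can be shown in miniature: input rows to a Final-mode operator look like [groupingCols..., bufferCols...], so dropping the first initialInputBufferOffset attributes isolates the buffer columns without relying on attribute ids.

    object BufferExtractionSketch {
      val originalInputAttributes = Seq("key", "sum", "count")
      val initialInputBufferOffset = 1
      val bufferPart = originalInputAttributes.drop(initialInputBufferOffset)
      // bufferPart == Seq("sum", "count")
    }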
 

http://git-wip-us.apache.org/repos/asf/spark/blob/4ace4f8a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index fd02be1..d2f56e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -321,9 +321,17 @@ private[sql] class InputAggregationBuffer private[sql] (
  */
 private[sql] case class ScalaUDAF(
     children: Seq[Expression],
-    udaf: UserDefinedAggregateFunction)
+    udaf: UserDefinedAggregateFunction,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
   extends ImperativeAggregate with Logging {
 
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
   require(
     children.length == udaf.inputSchema.length,
     s"$udaf only accepts ${udaf.inputSchema.length} arguments, " +
@@ -341,6 +349,11 @@ private[sql] case class ScalaUDAF(
 
   override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes
 
+  // Note: although this simply copies aggBufferAttributes, this common code cannot be placed
+  // in the superclass because that will lead to initialization ordering issues.
+  override val inputAggBufferAttributes: Seq[AttributeReference] =
+    aggBufferAttributes.map(_.newInstance())
+
   private[this] lazy val childrenSchema: StructType = {
     val inputFields = children.zipWithIndex.map {
       case (child, index) =>
@@ -382,51 +395,33 @@ private[sql] case class ScalaUDAF(
   }
 
   // This buffer is only used at executor side.
-  private[this] var inputAggregateBuffer: InputAggregationBuffer = null
-
-  // This buffer is only used at executor side.
-  private[this] var mutableAggregateBuffer: MutableAggregationBufferImpl = null
+  private[this] lazy val inputAggregateBuffer: InputAggregationBuffer = {
+    new InputAggregationBuffer(
+      aggBufferSchema,
+      bufferValuesToCatalystConverters,
+      bufferValuesToScalaConverters,
+      inputAggBufferOffset,
+      null)
+  }
 
   // This buffer is only used at executor side.
-  private[this] var evalAggregateBuffer: InputAggregationBuffer = null
-
-  /**
-   * Sets the inputBufferOffset to newInputBufferOffset and then create a new instance of
-   * `inputAggregateBuffer` based on this new inputBufferOffset.
-   */
-  override def withNewInputAggBufferOffset(newInputBufferOffset: Int): Unit = {
-    super.withNewInputAggBufferOffset(newInputBufferOffset)
-    // inputBufferOffset has been updated.
-    inputAggregateBuffer =
-      new InputAggregationBuffer(
-        aggBufferSchema,
-        bufferValuesToCatalystConverters,
-        bufferValuesToScalaConverters,
-        inputAggBufferOffset,
-        null)
+  private[this] lazy val mutableAggregateBuffer: MutableAggregationBufferImpl = {
+    new MutableAggregationBufferImpl(
+      aggBufferSchema,
+      bufferValuesToCatalystConverters,
+      bufferValuesToScalaConverters,
+      mutableAggBufferOffset,
+      null)
   }
 
-  /**
-   * Sets the mutableBufferOffset to newMutableBufferOffset and then create a new instance of
-   * `mutableAggregateBuffer` and `evalAggregateBuffer` based on this new mutableBufferOffset.
-   */
-  override def withNewMutableAggBufferOffset(newMutableBufferOffset: Int): Unit = {
-    super.withNewMutableAggBufferOffset(newMutableBufferOffset)
-    // mutableBufferOffset has been updated.
-    mutableAggregateBuffer =
-      new MutableAggregationBufferImpl(
-        aggBufferSchema,
-        bufferValuesToCatalystConverters,
-        bufferValuesToScalaConverters,
-        mutableAggBufferOffset,
-        null)
-    evalAggregateBuffer =
-      new InputAggregationBuffer(
-        aggBufferSchema,
-        bufferValuesToCatalystConverters,
-        bufferValuesToScalaConverters,
-        mutableAggBufferOffset,
-        null)
+  // This buffer is only used at executor side.
+  private[this] lazy val evalAggregateBuffer: InputAggregationBuffer = {
+    new InputAggregationBuffer(
+      aggBufferSchema,
+      bufferValuesToCatalystConverters,
+      bufferValuesToScalaConverters,
+      mutableAggBufferOffset,
+      null)
   }
 
   override def initialize(buffer: MutableRow): Unit = {
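
The buffer refactoring above follows directly from the offsets becoming immutable: each executor-side buffer can be a lazy val built once from the final offset, rather than a var re-created inside the old setters. In a hypothetical miniature:

    object LazyBufferSketch {
      class MiniBuffer(val offset: Int)

      case class MiniUDAF(mutableAggBufferOffset: Int = 0) {
        def withNewMutableAggBufferOffset(n: Int): MiniUDAF =
          copy(mutableAggBufferOffset = n)
        // Safe as a lazy val: the offset cannot change after construction,
        // so the buffer never needs to be rebuilt.
        lazy val mutableBuffer = new MiniBuffer(mutableAggBufferOffset)
      }
    }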

http://git-wip-us.apache.org/repos/asf/spark/blob/4ace4f8a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index cf6e7ed..eaafd83 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -19,21 +19,12 @@ package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.execution.SparkPlan
 
 /**
  * Utility functions used by the query planner to convert our plan to new aggregation code path.
  */
 object Utils {
-  def supportsTungstenAggregate(
-      groupingExpressions: Seq[Expression],
-      aggregateBufferAttributes: Seq[Attribute]): Boolean = {
-    val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
-
-    UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
-      UnsafeProjection.canSupport(groupingExpressions)
-  }
 
   def planAggregateWithoutPartial(
       groupingExpressions: Seq[NamedExpression],
@@ -70,8 +61,7 @@ object Utils {
     // Check if we can use TungstenAggregate.
     val usesTungstenAggregate =
       child.sqlContext.conf.unsafeEnabled &&
-      aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[DeclarativeAggregate]) &&
-      supportsTungstenAggregate(
+      TungstenAggregate.supportsAggregate(
         groupingExpressions,
         aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
 
@@ -94,6 +84,7 @@ object Utils {
         nonCompleteAggregateAttributes = partialAggregateAttributes,
         completeAggregateExpressions = Nil,
         completeAggregateAttributes = Nil,
+        initialInputBufferOffset = 0,
         resultExpressions = partialResultExpressions,
         child = child)
     } else {
@@ -125,6 +116,7 @@ object Utils {
         nonCompleteAggregateAttributes = finalAggregateAttributes,
         completeAggregateExpressions = Nil,
         completeAggregateAttributes = Nil,
+        initialInputBufferOffset = groupingExpressions.length,
         resultExpressions = resultExpressions,
         child = partialAggregate)
     } else {
@@ -154,143 +146,150 @@ object Utils {
     val aggregateExpressions = functionsWithDistinct ++ functionsWithoutDistinct
     val usesTungstenAggregate =
       child.sqlContext.conf.unsafeEnabled &&
-        aggregateExpressions.forall(
-          _.aggregateFunction.isInstanceOf[DeclarativeAggregate]) &&
-        supportsTungstenAggregate(
+        TungstenAggregate.supportsAggregate(
           groupingExpressions,
           aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
 
-    // 1. Create an Aggregate Operator for partial aggregations.
-    val groupingAttributes = groupingExpressions.map(_.toAttribute)
-
-    // It is safe to call head at here since functionsWithDistinct has at least one
-    // AggregateExpression2.
-    val distinctColumnExpressions =
-      functionsWithDistinct.head.aggregateFunction.children
-    val namedDistinctColumnExpressions = distinctColumnExpressions.map {
-      case ne: NamedExpression => ne -> ne
-      case other =>
-        val withAlias = Alias(other, other.toString)()
-        other -> withAlias
+    // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one
+    // DISTINCT aggregate function, all of those functions will have the same column expression.
+    // For example, it would be valid for functionsWithDistinct to be
+    // [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is
+    // disallowed because those two distinct aggregates have different column expressions.
+    val distinctColumnExpression: Expression = {
+      val allDistinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children
+      assert(allDistinctColumnExpressions.length == 1)
+      allDistinctColumnExpressions.head
+    }
+    val namedDistinctColumnExpression: NamedExpression = distinctColumnExpression match {
+      case ne: NamedExpression => ne
+      case other => Alias(other, other.toString)()
     }
-    val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap
-    val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute)
+    val distinctColumnAttribute: Attribute = namedDistinctColumnExpression.toAttribute
+    val groupingAttributes = groupingExpressions.map(_.toAttribute)
 
-    val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
-    val partialAggregateAttributes =
-      partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
-    val partialAggregateGroupingExpressions =
-      groupingExpressions ++ namedDistinctColumnExpressions.map(_._2)
-    val partialAggregateResult =
+    // 1. Create an Aggregate Operator for partial aggregations.
+    val partialAggregate: SparkPlan = {
+      val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
+      val partialAggregateAttributes =
+        partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+      // We will group by the original grouping expression, plus an additional expression for the
+      // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
+      // expressions will be [key, value].
+      val partialAggregateGroupingExpressions = groupingExpressions :+ namedDistinctColumnExpression
+      val partialAggregateResult =
         groupingAttributes ++
-        distinctColumnAttributes ++
-        partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
-    val partialAggregate = if (usesTungstenAggregate) {
-      TungstenAggregate(
-        requiredChildDistributionExpressions = None,
-        // The grouping expressions are original groupingExpressions and
-        // distinct columns. For example, for avg(distinct value) ... group by key
-        // the grouping expressions of this Aggregate Operator will be [key, value].
-        groupingExpressions = partialAggregateGroupingExpressions,
-        nonCompleteAggregateExpressions = partialAggregateExpressions,
-        nonCompleteAggregateAttributes = partialAggregateAttributes,
-        completeAggregateExpressions = Nil,
-        completeAggregateAttributes = Nil,
-        resultExpressions = partialAggregateResult,
-        child = child)
-    } else {
-      SortBasedAggregate(
-        requiredChildDistributionExpressions = None,
-        groupingExpressions = partialAggregateGroupingExpressions,
-        nonCompleteAggregateExpressions = partialAggregateExpressions,
-        nonCompleteAggregateAttributes = partialAggregateAttributes,
-        completeAggregateExpressions = Nil,
-        completeAggregateAttributes = Nil,
-        initialInputBufferOffset = 0,
-        resultExpressions = partialAggregateResult,
-        child = child)
+          Seq(distinctColumnAttribute) ++
+          partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
+      if (usesTungstenAggregate) {
+        TungstenAggregate(
+          requiredChildDistributionExpressions = None,
+          groupingExpressions = partialAggregateGroupingExpressions,
+          nonCompleteAggregateExpressions = partialAggregateExpressions,
+          nonCompleteAggregateAttributes = partialAggregateAttributes,
+          completeAggregateExpressions = Nil,
+          completeAggregateAttributes = Nil,
+          initialInputBufferOffset = 0,
+          resultExpressions = partialAggregateResult,
+          child = child)
+      } else {
+        SortBasedAggregate(
+          requiredChildDistributionExpressions = None,
+          groupingExpressions = partialAggregateGroupingExpressions,
+          nonCompleteAggregateExpressions = partialAggregateExpressions,
+          nonCompleteAggregateAttributes = partialAggregateAttributes,
+          completeAggregateExpressions = Nil,
+          completeAggregateAttributes = Nil,
+          initialInputBufferOffset = 0,
+          resultExpressions = partialAggregateResult,
+          child = child)
+      }
     }
 
     // 2. Create an Aggregate Operator for partial merge aggregations.
-    val partialMergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
-    val partialMergeAggregateAttributes =
-      partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
-    val partialMergeAggregateResult =
+    val partialMergeAggregate: SparkPlan = {
+      val partialMergeAggregateExpressions =
+        functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
+      val partialMergeAggregateAttributes =
+        partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+      val partialMergeAggregateResult =
         groupingAttributes ++
-        distinctColumnAttributes ++
-        partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
-    val partialMergeAggregate = if (usesTungstenAggregate) {
-      TungstenAggregate(
-        requiredChildDistributionExpressions = Some(groupingAttributes),
-        groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
-        nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
-        nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
-        completeAggregateExpressions = Nil,
-        completeAggregateAttributes = Nil,
-        resultExpressions = partialMergeAggregateResult,
-        child = partialAggregate)
-    } else {
-      SortBasedAggregate(
-        requiredChildDistributionExpressions = Some(groupingAttributes),
-        groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
-        nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
-        nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
-        completeAggregateExpressions = Nil,
-        completeAggregateAttributes = Nil,
-        initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
-        resultExpressions = partialMergeAggregateResult,
-        child = partialAggregate)
+          Seq(distinctColumnAttribute) ++
+          partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
+      if (usesTungstenAggregate) {
+        TungstenAggregate(
+          requiredChildDistributionExpressions = Some(groupingAttributes),
+          groupingExpressions = groupingAttributes :+ distinctColumnAttribute,
+          nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
+          nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
+          completeAggregateExpressions = Nil,
+          completeAggregateAttributes = Nil,
+          initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+          resultExpressions = partialMergeAggregateResult,
+          child = partialAggregate)
+      } else {
+        SortBasedAggregate(
+          requiredChildDistributionExpressions = Some(groupingAttributes),
+          groupingExpressions = groupingAttributes :+ distinctColumnAttribute,
+          nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
+          nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
+          completeAggregateExpressions = Nil,
+          completeAggregateAttributes = Nil,
+          initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+          resultExpressions = partialMergeAggregateResult,
+          child = partialAggregate)
+      }
     }
 
-    // 3. Create an Aggregate Operator for partial merge aggregations.
-    val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
-    // The attributes of the final aggregation buffer, which is presented as input to the result
-    // projection:
-    val finalAggregateAttributes = finalAggregateExpressions.map {
-      expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
-    }
+    // 3. Create an Aggregate Operator for the final aggregation.
+    val finalAndCompleteAggregate: SparkPlan = {
+      val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
+      // The attributes of the final aggregation buffer, which is presented as input to the result
+      // projection:
+      val finalAggregateAttributes = finalAggregateExpressions.map {
+        expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
+      }
 
-    val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
-      // Children of an AggregateFunction with DISTINCT keyword has already
-      // been evaluated. At here, we need to replace original children
-      // to AttributeReferences.
-      case agg @ AggregateExpression2(aggregateFunction, mode, true) =>
-        val rewrittenAggregateFunction = aggregateFunction.transformDown {
-          case expr if distinctColumnExpressionMap.contains(expr) =>
-            distinctColumnExpressionMap(expr).toAttribute
-        }.asInstanceOf[AggregateFunction2]
-        // We rewrite the aggregate function to a non-distinct aggregation because
-        // its input will have distinct arguments.
-        // We just keep the isDistinct setting to true, so when users look at the query plan,
-        // they still can see distinct aggregations.
-        val rewrittenAggregateExpression =
-          AggregateExpression2(rewrittenAggregateFunction, Complete, true)
+      val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
+        // Children of an AggregateFunction with the DISTINCT keyword have already
+        // been evaluated. Here, we need to replace the original children with
+        // AttributeReferences.
+        case agg @ AggregateExpression2(aggregateFunction, mode, true) =>
+          val rewrittenAggregateFunction = aggregateFunction.transformDown {
+            case expr if expr == distinctColumnExpression => distinctColumnAttribute
+          }.asInstanceOf[AggregateFunction2]
+          // We rewrite the aggregate function to a non-distinct aggregation because
+          // its input will have distinct arguments.
+          // We just keep the isDistinct setting to true, so when users look at the query plan,
+          // they still can see distinct aggregations.
+          val rewrittenAggregateExpression =
+            AggregateExpression2(rewrittenAggregateFunction, Complete, isDistinct = true)
 
-        val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true)
-        (rewrittenAggregateExpression, aggregateFunctionAttribute)
-    }.unzip
-
-    val finalAndCompleteAggregate = if (usesTungstenAggregate) {
-      TungstenAggregate(
-        requiredChildDistributionExpressions = Some(groupingAttributes),
-        groupingExpressions = groupingAttributes,
-        nonCompleteAggregateExpressions = finalAggregateExpressions,
-        nonCompleteAggregateAttributes = finalAggregateAttributes,
-        completeAggregateExpressions = completeAggregateExpressions,
-        completeAggregateAttributes = completeAggregateAttributes,
-        resultExpressions = resultExpressions,
-        child = partialMergeAggregate)
-    } else {
-      SortBasedAggregate(
-        requiredChildDistributionExpressions = Some(groupingAttributes),
-        groupingExpressions = groupingAttributes,
-        nonCompleteAggregateExpressions = finalAggregateExpressions,
-        nonCompleteAggregateAttributes = finalAggregateAttributes,
-        completeAggregateExpressions = completeAggregateExpressions,
-        completeAggregateAttributes = completeAggregateAttributes,
-        initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
-        resultExpressions = resultExpressions,
-        child = partialMergeAggregate)
+          val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true)
+          (rewrittenAggregateExpression, aggregateFunctionAttribute)
+      }.unzip
+      if (usesTungstenAggregate) {
+        TungstenAggregate(
+          requiredChildDistributionExpressions = Some(groupingAttributes),
+          groupingExpressions = groupingAttributes,
+          nonCompleteAggregateExpressions = finalAggregateExpressions,
+          nonCompleteAggregateAttributes = finalAggregateAttributes,
+          completeAggregateExpressions = completeAggregateExpressions,
+          completeAggregateAttributes = completeAggregateAttributes,
+          initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+          resultExpressions = resultExpressions,
+          child = partialMergeAggregate)
+      } else {
+        SortBasedAggregate(
+          requiredChildDistributionExpressions = Some(groupingAttributes),
+          groupingExpressions = groupingAttributes,
+          nonCompleteAggregateExpressions = finalAggregateExpressions,
+          nonCompleteAggregateAttributes = finalAggregateAttributes,
+          completeAggregateExpressions = completeAggregateExpressions,
+          completeAggregateAttributes = completeAggregateAttributes,
+          initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+          resultExpressions = resultExpressions,
+          child = partialMergeAggregate)
+      }
     }
 
     finalAndCompleteAggregate :: Nil
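
To make the three-operator pipeline concrete, here is a sketch (not actual EXPLAIN output, and assuming the Tungsten path is chosen) of how SELECT key, COUNT(DISTINCT value) FROM t GROUP BY key would be planned by this method, top operator first:

    TungstenAggregate(Final + Complete, groupBy = [key])
      -- COUNT is rewritten to a non-distinct Complete aggregate over the
      -- already-distinct value column; initialInputBufferOffset = 2
      TungstenAggregate(PartialMerge, groupBy = [key, value])
        TungstenAggregate(Partial, groupBy = [key, value])
          -- grouping on [key, value] de-duplicates value within each key
          Scan t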

http://git-wip-us.apache.org/repos/asf/spark/blob/4ace4f8a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
index ed974b3..0cc4988 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
@@ -39,7 +39,7 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLConte
       }
       val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy")
       iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty,
-        Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum)
+        0, Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum)
       val numPages = iter.getHashMap.getNumDataPages
       assert(numPages === 1)
     } finally {

http://git-wip-us.apache.org/repos/asf/spark/blob/4ace4f8a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 18bbdb9..a2ebf65 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -553,10 +553,16 @@ private[hive] case class HiveGenericUDTF(
 private[hive] case class HiveUDAFFunction(
     funcWrapper: HiveFunctionWrapper,
     children: Seq[Expression],
-    isUDAFBridgeRequired: Boolean = false)
+    isUDAFBridgeRequired: Boolean = false,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
   extends ImperativeAggregate with HiveInspectors {
 
-  def this() = this(null, null)
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
 
   @transient
   private lazy val resolver =
@@ -614,7 +620,11 @@ private[hive] case class HiveUDAFFunction(
     buffer = function.getNewAggregationBuffer
   }
 
-  override def aggBufferAttributes: Seq[AttributeReference] = Nil
+  override val aggBufferAttributes: Seq[AttributeReference] = Nil
+
+  // Note: although this simply copies aggBufferAttributes, this common code cannot be placed
+  // in the superclass because that will lead to initialization ordering issues.
+  override val inputAggBufferAttributes: Seq[AttributeReference] = Nil
 
   // We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our
   // catalyst type checking framework.

