You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yh...@apache.org on 2015/10/08 23:56:39 UTC

spark git commit: [SPARK-10988] [SQL] Reduce duplication in Aggregate2's expression rewriting logic

Repository: spark
Updated Branches:
  refs/heads/master 9e66a53c9 -> 2816c89b6


[SPARK-10988] [SQL] Reduce duplication in Aggregate2's expression rewriting logic

In `aggregate/utils.scala`, there is a substantial amount of duplication in the expression-rewriting logic. As a prerequisite to supporting imperative aggregate functions in `TungstenAggregate`, this patch refactors this file so that the same expression-rewriting logic is used for both `SortAggregate` and `TungstenAggregate`.

In order to allow both operators to use the same rewriting logic, `TungstenAggregationIterator.generateResultProjection()` has been updated so that it first evaluates all declarative aggregate functions' `evaluateExpression`s and writes the results into a temporary buffer, and then uses this temporary buffer and the grouping expressions to evaluate the final `resultExpressions`. This matches the logic in `SortAggregateIterator`, where this two-pass approach is necessary in order to support imperative aggregates. If this change turns out to cause performance regressions, then we can look into re-implementing the single-pass evaluation in a cleaner way as part of a follow-up patch.

Since the rewriting logic is now shared across both operators, this patch also extracts that logic and places it in `SparkStrategies`. This makes the rewriting logic a bit easier to follow, I think.

Author: Josh Rosen <jo...@databricks.com>

Closes #9015 from JoshRosen/SPARK-10988.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2816c89b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2816c89b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2816c89b

Branch: refs/heads/master
Commit: 2816c89b6a304cb0b5214e14ebbc320158e88260
Parents: 9e66a53
Author: Josh Rosen <jo...@databricks.com>
Authored: Thu Oct 8 14:53:21 2015 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Thu Oct 8 14:56:27 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/execution/SparkStrategies.scala   |  67 +++--
 .../execution/aggregate/TungstenAggregate.scala |   4 +
 .../aggregate/TungstenAggregationIterator.scala |  22 +-
 .../spark/sql/execution/aggregate/utils.scala   | 244 ++++++-------------
 .../TungstenAggregationIteratorSuite.scala      |   2 +-
 5 files changed, 143 insertions(+), 196 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2816c89b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index d1bbf2e..79bd1a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -195,19 +195,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         converted match {
           case None => Nil // Cannot convert to new aggregation code path.
           case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) =>
-            // Extracts all distinct aggregate expressions from the resultExpressions.
+            // A single aggregate expression might appear multiple times in resultExpressions.
+            // In order to avoid evaluating an individual aggregate function multiple times, we'll
+            // build a set of the distinct aggregate expressions and build a function which can
+            // be used to re-write expressions so that they reference the single copy of the
+            // aggregate function which actually gets computed.
             val aggregateExpressions = resultExpressions.flatMap { expr =>
               expr.collect {
                 case agg: AggregateExpression2 => agg
               }
-            }.toSet.toSeq
+            }.distinct
             // For those distinct aggregate expressions, we create a map from the
             // aggregate function to the corresponding attribute of the function.
-            val aggregateFunctionMap = aggregateExpressions.map { agg =>
+            val aggregateFunctionToAttribute = aggregateExpressions.map { agg =>
               val aggregateFunction = agg.aggregateFunction
-              val attribtue = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
-              (aggregateFunction, agg.isDistinct) ->
-                (aggregateFunction -> attribtue)
+              val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
+              (aggregateFunction, agg.isDistinct) -> attribute
             }.toMap
 
             val (functionsWithDistinct, functionsWithoutDistinct) =
@@ -220,6 +223,40 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
                   "code path.")
             }
 
+            val namedGroupingExpressions = groupingExpressions.map {
+              case ne: NamedExpression => ne -> ne
+              // If the expression is not a NamedExpressions, we add an alias.
+              // So, when we generate the result of the operator, the Aggregate Operator
+              // can directly get the Seq of attributes representing the grouping expressions.
+              case other =>
+                val withAlias = Alias(other, other.toString)()
+                other -> withAlias
+            }
+            val groupExpressionMap = namedGroupingExpressions.toMap
+
+            // The original `resultExpressions` are a set of expressions which may reference
+            // aggregate expressions, grouping column values, and constants. When aggregate operator
+            // emits output rows, we will use `resultExpressions` to generate an output projection
+            // which takes the grouping columns and final aggregate result buffer as input.
+            // Thus, we must re-write the result expressions so that their attributes match up with
+            // the attributes of the final result projection's input row:
+            val rewrittenResultExpressions = resultExpressions.map { expr =>
+              expr.transformDown {
+                case AggregateExpression2(aggregateFunction, _, isDistinct) =>
+                  // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
+                  // so replace each aggregate expression by its corresponding attribute in the set:
+                  aggregateFunctionToAttribute(aggregateFunction, isDistinct)
+                case expression =>
+                  // Since we're using `namedGroupingAttributes` to extract the grouping key
+                  // columns, we need to replace grouping key expressions with their corresponding
+                  // attributes. We do not rely on the equality check at here since attributes may
+                  // differ cosmetically. Instead, we use semanticEquals.
+                  groupExpressionMap.collectFirst {
+                    case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+                  }.getOrElse(expression)
+              }.asInstanceOf[NamedExpression]
+            }
+
             val aggregateOperator =
               if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
                 if (functionsWithDistinct.nonEmpty) {
@@ -227,26 +264,26 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
                     "aggregate functions which don't support partial aggregation.")
                 } else {
                   aggregate.Utils.planAggregateWithoutPartial(
-                    groupingExpressions,
+                    namedGroupingExpressions.map(_._2),
                     aggregateExpressions,
-                    aggregateFunctionMap,
-                    resultExpressions,
+                    aggregateFunctionToAttribute,
+                    rewrittenResultExpressions,
                     planLater(child))
                 }
               } else if (functionsWithDistinct.isEmpty) {
                 aggregate.Utils.planAggregateWithoutDistinct(
-                  groupingExpressions,
+                  namedGroupingExpressions.map(_._2),
                   aggregateExpressions,
-                  aggregateFunctionMap,
-                  resultExpressions,
+                  aggregateFunctionToAttribute,
+                  rewrittenResultExpressions,
                   planLater(child))
               } else {
                 aggregate.Utils.planAggregateWithOneDistinct(
-                  groupingExpressions,
+                  namedGroupingExpressions.map(_._2),
                   functionsWithDistinct,
                   functionsWithoutDistinct,
-                  aggregateFunctionMap,
-                  resultExpressions,
+                  aggregateFunctionToAttribute,
+                  rewrittenResultExpressions,
                   planLater(child))
               }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2816c89b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 3cd22af..7b3d072 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -31,7 +31,9 @@ case class TungstenAggregate(
     requiredChildDistributionExpressions: Option[Seq[Expression]],
     groupingExpressions: Seq[NamedExpression],
     nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+    nonCompleteAggregateAttributes: Seq[Attribute],
     completeAggregateExpressions: Seq[AggregateExpression2],
+    completeAggregateAttributes: Seq[Attribute],
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
   extends UnaryNode {
@@ -77,7 +79,9 @@ case class TungstenAggregate(
       new TungstenAggregationIterator(
         groupingExpressions,
         nonCompleteAggregateExpressions,
+        nonCompleteAggregateAttributes,
         completeAggregateExpressions,
+        completeAggregateAttributes,
         resultExpressions,
         newMutableProjection,
         child.output,

http://git-wip-us.apache.org/repos/asf/spark/blob/2816c89b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index a6f4c1d..4bb95c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -60,8 +60,12 @@ import org.apache.spark.sql.types.StructType
  * @param nonCompleteAggregateExpressions
  *   [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Partial]],
  *   [[PartialMerge]], or [[Final]].
+ * @param nonCompleteAggregateAttributes the attributes of the nonCompleteAggregateExpressions'
+ *   outputs when they are stored in the final aggregation buffer.
  * @param completeAggregateExpressions
  *   [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Complete]].
+ * @param completeAggregateAttributes the attributes of completeAggregateExpressions' outputs
+ *   when they are stored in the final aggregation buffer.
  * @param resultExpressions
  *   expressions for generating output rows.
  * @param newMutableProjection
@@ -72,7 +76,9 @@ import org.apache.spark.sql.types.StructType
 class TungstenAggregationIterator(
     groupingExpressions: Seq[NamedExpression],
     nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+    nonCompleteAggregateAttributes: Seq[Attribute],
     completeAggregateExpressions: Seq[AggregateExpression2],
+    completeAggregateAttributes: Seq[Attribute],
     resultExpressions: Seq[NamedExpression],
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
     originalInputAttributes: Seq[Attribute],
@@ -280,17 +286,25 @@ class TungstenAggregationIterator(
       // resultExpressions.
       case (Some(Final), None) | (Some(Final) | None, Some(Complete)) =>
         val joinedRow = new JoinedRow()
+        val evalExpressions = allAggregateFunctions.map {
+          case ae: DeclarativeAggregate => ae.evaluateExpression
+          // case agg: AggregateFunction2 => Literal.create(null, agg.dataType)
+        }
+        val expressionAggEvalProjection = UnsafeProjection.create(evalExpressions, bufferAttributes)
+        // These are the attributes of the row produced by `expressionAggEvalProjection`
+        val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes
         val resultProjection =
-          UnsafeProjection.create(resultExpressions, groupingAttributes ++ bufferAttributes)
+          UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateResultSchema)
 
         (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
-          resultProjection(joinedRow(currentGroupingKey, currentBuffer))
+          // Generate results for all expression-based aggregate functions.
+          val aggregateResult = expressionAggEvalProjection.apply(currentBuffer)
+          resultProjection(joinedRow(currentGroupingKey, aggregateResult))
         }
 
       // Grouping-only: a output row is generated from values of grouping expressions.
       case (None, None) =>
-        val resultProjection =
-          UnsafeProjection.create(resultExpressions, groupingAttributes)
+        val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes)
 
         (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
           resultProjection(currentGroupingKey)

http://git-wip-us.apache.org/repos/asf/spark/blob/2816c89b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index e1c2d94..cf6e7ed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.execution.aggregate
 
-import scala.collection.mutable
-
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan}
@@ -38,60 +36,35 @@ object Utils {
   }
 
   def planAggregateWithoutPartial(
-      groupingExpressions: Seq[Expression],
+      groupingExpressions: Seq[NamedExpression],
       aggregateExpressions: Seq[AggregateExpression2],
-      aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)],
+      aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute],
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {
 
-    val namedGroupingExpressions = groupingExpressions.map {
-      case ne: NamedExpression => ne -> ne
-      // If the expression is not a NamedExpressions, we add an alias.
-      // So, when we generate the result of the operator, the Aggregate Operator
-      // can directly get the Seq of attributes representing the grouping expressions.
-      case other =>
-        val withAlias = Alias(other, other.toString)()
-        other -> withAlias
-    }
-    val groupExpressionMap = namedGroupingExpressions.toMap
-    val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
-
+    val groupingAttributes = groupingExpressions.map(_.toAttribute)
     val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
-    val completeAggregateAttributes =
-      completeAggregateExpressions.map {
-        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2
-      }
-
-    val rewrittenResultExpressions = resultExpressions.map { expr =>
-      expr.transformDown {
-        case agg: AggregateExpression2 =>
-          aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2
-        case expression =>
-          // We do not rely on the equality check at here since attributes may
-          // different cosmetically. Instead, we use semanticEquals.
-          groupExpressionMap.collectFirst {
-            case (expr, ne) if expr semanticEquals expression => ne.toAttribute
-          }.getOrElse(expression)
-      }.asInstanceOf[NamedExpression]
+    val completeAggregateAttributes = completeAggregateExpressions.map {
+      expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
     }
 
     SortBasedAggregate(
-      requiredChildDistributionExpressions = Some(namedGroupingAttributes),
-      groupingExpressions = namedGroupingExpressions.map(_._2),
+      requiredChildDistributionExpressions = Some(groupingAttributes),
+      groupingExpressions = groupingAttributes,
       nonCompleteAggregateExpressions = Nil,
       nonCompleteAggregateAttributes = Nil,
       completeAggregateExpressions = completeAggregateExpressions,
       completeAggregateAttributes = completeAggregateAttributes,
       initialInputBufferOffset = 0,
-      resultExpressions = rewrittenResultExpressions,
+      resultExpressions = resultExpressions,
       child = child
     ) :: Nil
   }
 
   def planAggregateWithoutDistinct(
-      groupingExpressions: Seq[Expression],
+      groupingExpressions: Seq[NamedExpression],
       aggregateExpressions: Seq[AggregateExpression2],
-      aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)],
+      aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute],
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {
     // Check if we can use TungstenAggregate.
@@ -104,36 +77,29 @@ object Utils {
 
 
     // 1. Create an Aggregate Operator for partial aggregations.
-    val namedGroupingExpressions = groupingExpressions.map {
-      case ne: NamedExpression => ne -> ne
-      // If the expression is not a NamedExpressions, we add an alias.
-      // So, when we generate the result of the operator, the Aggregate Operator
-      // can directly get the Seq of attributes representing the grouping expressions.
-      case other =>
-        val withAlias = Alias(other, other.toString)()
-        other -> withAlias
-    }
-    val groupExpressionMap = namedGroupingExpressions.toMap
-    val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
+
+    val groupingAttributes = groupingExpressions.map(_.toAttribute)
     val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial))
     val partialAggregateAttributes =
       partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
     val partialResultExpressions =
-      namedGroupingAttributes ++
+      groupingAttributes ++
         partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
 
     val partialAggregate = if (usesTungstenAggregate) {
       TungstenAggregate(
         requiredChildDistributionExpressions = None: Option[Seq[Expression]],
-        groupingExpressions = namedGroupingExpressions.map(_._2),
+        groupingExpressions = groupingExpressions,
         nonCompleteAggregateExpressions = partialAggregateExpressions,
+        nonCompleteAggregateAttributes = partialAggregateAttributes,
         completeAggregateExpressions = Nil,
+        completeAggregateAttributes = Nil,
         resultExpressions = partialResultExpressions,
         child = child)
     } else {
       SortBasedAggregate(
         requiredChildDistributionExpressions = None: Option[Seq[Expression]],
-        groupingExpressions = namedGroupingExpressions.map(_._2),
+        groupingExpressions = groupingExpressions,
         nonCompleteAggregateExpressions = partialAggregateExpressions,
         nonCompleteAggregateAttributes = partialAggregateAttributes,
         completeAggregateExpressions = Nil,
@@ -145,58 +111,32 @@ object Utils {
 
     // 2. Create an Aggregate Operator for final aggregations.
     val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
-    val finalAggregateAttributes =
-      finalAggregateExpressions.map {
-        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2
-      }
+    // The attributes of the final aggregation buffer, which is presented as input to the result
+    // projection:
+    val finalAggregateAttributes = finalAggregateExpressions.map {
+      expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
+    }
 
     val finalAggregate = if (usesTungstenAggregate) {
-      val rewrittenResultExpressions = resultExpressions.map { expr =>
-        expr.transformDown {
-          case agg: AggregateExpression2 =>
-            // aggregateFunctionMap contains unique aggregate functions.
-            val aggregateFunction =
-              aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._1
-            aggregateFunction.asInstanceOf[DeclarativeAggregate].evaluateExpression
-          case expression =>
-            // We do not rely on the equality check at here since attributes may
-            // different cosmetically. Instead, we use semanticEquals.
-            groupExpressionMap.collectFirst {
-              case (expr, ne) if expr semanticEquals expression => ne.toAttribute
-            }.getOrElse(expression)
-        }.asInstanceOf[NamedExpression]
-      }
-
       TungstenAggregate(
-        requiredChildDistributionExpressions = Some(namedGroupingAttributes),
-        groupingExpressions = namedGroupingAttributes,
+        requiredChildDistributionExpressions = Some(groupingAttributes),
+        groupingExpressions = groupingAttributes,
         nonCompleteAggregateExpressions = finalAggregateExpressions,
+        nonCompleteAggregateAttributes = finalAggregateAttributes,
         completeAggregateExpressions = Nil,
-        resultExpressions = rewrittenResultExpressions,
+        completeAggregateAttributes = Nil,
+        resultExpressions = resultExpressions,
         child = partialAggregate)
     } else {
-      val rewrittenResultExpressions = resultExpressions.map { expr =>
-        expr.transformDown {
-          case agg: AggregateExpression2 =>
-            aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2
-          case expression =>
-            // We do not rely on the equality check at here since attributes may
-            // different cosmetically. Instead, we use semanticEquals.
-            groupExpressionMap.collectFirst {
-              case (expr, ne) if expr semanticEquals expression => ne.toAttribute
-            }.getOrElse(expression)
-        }.asInstanceOf[NamedExpression]
-      }
-
       SortBasedAggregate(
-        requiredChildDistributionExpressions = Some(namedGroupingAttributes),
-        groupingExpressions = namedGroupingAttributes,
+        requiredChildDistributionExpressions = Some(groupingAttributes),
+        groupingExpressions = groupingAttributes,
         nonCompleteAggregateExpressions = finalAggregateExpressions,
         nonCompleteAggregateAttributes = finalAggregateAttributes,
         completeAggregateExpressions = Nil,
         completeAggregateAttributes = Nil,
-        initialInputBufferOffset = namedGroupingAttributes.length,
-        resultExpressions = rewrittenResultExpressions,
+        initialInputBufferOffset = groupingExpressions.length,
+        resultExpressions = resultExpressions,
         child = partialAggregate)
     }
 
@@ -204,10 +144,10 @@ object Utils {
   }
 
   def planAggregateWithOneDistinct(
-      groupingExpressions: Seq[Expression],
+      groupingExpressions: Seq[NamedExpression],
       functionsWithDistinct: Seq[AggregateExpression2],
       functionsWithoutDistinct: Seq[AggregateExpression2],
-      aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)],
+      aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute],
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {
 
@@ -221,20 +161,7 @@ object Utils {
           aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
 
     // 1. Create an Aggregate Operator for partial aggregations.
-    // The grouping expressions are original groupingExpressions and
-    // distinct columns. For example, for avg(distinct value) ... group by key
-    // the grouping expressions of this Aggregate Operator will be [key, value].
-    val namedGroupingExpressions = groupingExpressions.map {
-      case ne: NamedExpression => ne -> ne
-      // If the expression is not a NamedExpressions, we add an alias.
-      // So, when we generate the result of the operator, the Aggregate Operator
-      // can directly get the Seq of attributes representing the grouping expressions.
-      case other =>
-        val withAlias = Alias(other, other.toString)()
-        other -> withAlias
-    }
-    val groupExpressionMap = namedGroupingExpressions.toMap
-    val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
+    val groupingAttributes = groupingExpressions.map(_.toAttribute)
 
     // It is safe to call head at here since functionsWithDistinct has at least one
     // AggregateExpression2.
@@ -253,22 +180,27 @@ object Utils {
     val partialAggregateAttributes =
       partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
     val partialAggregateGroupingExpressions =
-      (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2)
+      groupingExpressions ++ namedDistinctColumnExpressions.map(_._2)
     val partialAggregateResult =
-      namedGroupingAttributes ++
+        groupingAttributes ++
         distinctColumnAttributes ++
         partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
     val partialAggregate = if (usesTungstenAggregate) {
       TungstenAggregate(
-        requiredChildDistributionExpressions = None: Option[Seq[Expression]],
+        requiredChildDistributionExpressions = None,
+        // The grouping expressions are original groupingExpressions and
+        // distinct columns. For example, for avg(distinct value) ... group by key
+        // the grouping expressions of this Aggregate Operator will be [key, value].
         groupingExpressions = partialAggregateGroupingExpressions,
         nonCompleteAggregateExpressions = partialAggregateExpressions,
+        nonCompleteAggregateAttributes = partialAggregateAttributes,
         completeAggregateExpressions = Nil,
+        completeAggregateAttributes = Nil,
         resultExpressions = partialAggregateResult,
         child = child)
     } else {
       SortBasedAggregate(
-        requiredChildDistributionExpressions = None: Option[Seq[Expression]],
+        requiredChildDistributionExpressions = None,
         groupingExpressions = partialAggregateGroupingExpressions,
         nonCompleteAggregateExpressions = partialAggregateExpressions,
         nonCompleteAggregateAttributes = partialAggregateAttributes,
@@ -284,41 +216,40 @@ object Utils {
     val partialMergeAggregateAttributes =
       partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
     val partialMergeAggregateResult =
-      namedGroupingAttributes ++
+        groupingAttributes ++
         distinctColumnAttributes ++
         partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
     val partialMergeAggregate = if (usesTungstenAggregate) {
       TungstenAggregate(
-        requiredChildDistributionExpressions = Some(namedGroupingAttributes),
-        groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes,
+        requiredChildDistributionExpressions = Some(groupingAttributes),
+        groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
         nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
+        nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
         completeAggregateExpressions = Nil,
+        completeAggregateAttributes = Nil,
         resultExpressions = partialMergeAggregateResult,
         child = partialAggregate)
     } else {
       SortBasedAggregate(
-        requiredChildDistributionExpressions = Some(namedGroupingAttributes),
-        groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes,
+        requiredChildDistributionExpressions = Some(groupingAttributes),
+        groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
         nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
         nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
         completeAggregateExpressions = Nil,
         completeAggregateAttributes = Nil,
-        initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
+        initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
         resultExpressions = partialMergeAggregateResult,
         child = partialAggregate)
     }
 
     // 3. Create an Aggregate Operator for partial merge aggregations.
     val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
-    val finalAggregateAttributes =
-      finalAggregateExpressions.map {
-        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2
-      }
-    // Create a map to store those rewritten aggregate functions. We always need to use
-    // both function and its corresponding isDistinct flag as the key because function itself
-    // does not knows if it is has distinct keyword or now.
-    val rewrittenAggregateFunctions =
-      mutable.Map.empty[(AggregateFunction2, Boolean), AggregateFunction2]
+    // The attributes of the final aggregation buffer, which is presented as input to the result
+    // projection:
+    val finalAggregateAttributes = finalAggregateExpressions.map {
+      expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
+    }
+
     val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
       // Children of an AggregateFunction with DISTINCT keyword has already
       // been evaluated. At here, we need to replace original children
@@ -328,9 +259,6 @@ object Utils {
           case expr if distinctColumnExpressionMap.contains(expr) =>
             distinctColumnExpressionMap(expr).toAttribute
         }.asInstanceOf[AggregateFunction2]
-        // Because we have rewritten the aggregate function, we use rewrittenAggregateFunctions
-        // to track the old version and the new version of this function.
-        rewrittenAggregateFunctions += (aggregateFunction, true) -> rewrittenAggregateFunction
         // We rewrite the aggregate function to a non-distinct aggregation because
         // its input will have distinct arguments.
         // We just keep the isDistinct setting to true, so when users look at the query plan,
@@ -338,66 +266,30 @@ object Utils {
         val rewrittenAggregateExpression =
           AggregateExpression2(rewrittenAggregateFunction, Complete, true)
 
-        val aggregateFunctionAttribute =
-          aggregateFunctionMap(agg.aggregateFunction, true)._2
-        (rewrittenAggregateExpression -> aggregateFunctionAttribute)
+        val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true)
+        (rewrittenAggregateExpression, aggregateFunctionAttribute)
     }.unzip
 
     val finalAndCompleteAggregate = if (usesTungstenAggregate) {
-      val rewrittenResultExpressions = resultExpressions.map { expr =>
-        expr.transform {
-          case agg: AggregateExpression2 =>
-            val function = agg.aggregateFunction
-            val isDistinct = agg.isDistinct
-            val aggregateFunction =
-              if (rewrittenAggregateFunctions.contains(function, isDistinct)) {
-                // If this function has been rewritten, we get the rewritten version from
-                // rewrittenAggregateFunctions.
-                rewrittenAggregateFunctions(function, isDistinct)
-              } else {
-                // Oterwise, we get it from aggregateFunctionMap, which contains unique
-                // aggregate functions that have not been rewritten.
-                aggregateFunctionMap(function, isDistinct)._1
-              }
-            aggregateFunction.asInstanceOf[DeclarativeAggregate].evaluateExpression
-          case expression =>
-            // We do not rely on the equality check at here since attributes may
-            // different cosmetically. Instead, we use semanticEquals.
-            groupExpressionMap.collectFirst {
-              case (expr, ne) if expr semanticEquals expression => ne.toAttribute
-            }.getOrElse(expression)
-        }.asInstanceOf[NamedExpression]
-      }
-
       TungstenAggregate(
-        requiredChildDistributionExpressions = Some(namedGroupingAttributes),
-        groupingExpressions = namedGroupingAttributes,
+        requiredChildDistributionExpressions = Some(groupingAttributes),
+        groupingExpressions = groupingAttributes,
         nonCompleteAggregateExpressions = finalAggregateExpressions,
+        nonCompleteAggregateAttributes = finalAggregateAttributes,
         completeAggregateExpressions = completeAggregateExpressions,
-        resultExpressions = rewrittenResultExpressions,
+        completeAggregateAttributes = completeAggregateAttributes,
+        resultExpressions = resultExpressions,
         child = partialMergeAggregate)
     } else {
-      val rewrittenResultExpressions = resultExpressions.map { expr =>
-        expr.transform {
-          case agg: AggregateExpression2 =>
-            aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2
-          case expression =>
-            // We do not rely on the equality check at here since attributes may
-            // different cosmetically. Instead, we use semanticEquals.
-            groupExpressionMap.collectFirst {
-              case (expr, ne) if expr semanticEquals expression => ne.toAttribute
-            }.getOrElse(expression)
-        }.asInstanceOf[NamedExpression]
-      }
       SortBasedAggregate(
-        requiredChildDistributionExpressions = Some(namedGroupingAttributes),
-        groupingExpressions = namedGroupingAttributes,
+        requiredChildDistributionExpressions = Some(groupingAttributes),
+        groupingExpressions = groupingAttributes,
         nonCompleteAggregateExpressions = finalAggregateExpressions,
         nonCompleteAggregateAttributes = finalAggregateAttributes,
         completeAggregateExpressions = completeAggregateExpressions,
         completeAggregateAttributes = completeAggregateAttributes,
-        initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
-        resultExpressions = rewrittenResultExpressions,
+        initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
+        resultExpressions = resultExpressions,
         child = partialMergeAggregate)
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2816c89b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
index 7ca677a..ed974b3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
@@ -38,7 +38,7 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLConte
         () => new InterpretedMutableProjection(expr, schema)
       }
       val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy")
-      iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty,
+      iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty,
         Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum)
       val numPages = iter.getHashMap.getNumDataPages
       assert(numPages === 1)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org