You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2016/09/01 05:19:13 UTC

spark git commit: revert PR#10896 and PR#14865

Repository: spark
Updated Branches:
  refs/heads/master 7a5000f39 -> aaf632b21


revert PR#10896 and PR#14865

## What changes were proposed in this pull request?

According to the discussion in the original PR #10896 and the new approach proposed in PR #14876, we decided to revert these two PRs and go with the new approach.

## How was this patch tested?

N/A

Author: Wenchen Fan <we...@databricks.com>

Closes #14909 from cloud-fan/revert.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/aaf632b2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/aaf632b2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/aaf632b2

Branch: refs/heads/master
Commit: aaf632b2132750c697dddd0469b902d9308dbf36
Parents: 7a5000f
Author: Wenchen Fan <we...@databricks.com>
Authored: Thu Sep 1 13:19:15 2016 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Thu Sep 1 13:19:15 2016 +0800

----------------------------------------------------------------------
 .../spark/sql/execution/SparkStrategies.scala   |  17 +-
 .../sql/execution/aggregate/AggUtils.scala      | 250 ++++++++++---------
 .../sql/execution/aggregate/AggregateExec.scala |  56 -----
 .../execution/aggregate/HashAggregateExec.scala |  22 +-
 .../execution/aggregate/SortAggregateExec.scala |  24 +-
 .../execution/exchange/EnsureRequirements.scala |  39 +--
 .../org/apache/spark/sql/DataFrameSuite.scala   |  15 +-
 .../spark/sql/execution/PlannerSuite.scala      |  77 ++----
 8 files changed, 223 insertions(+), 277 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/aaf632b2/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index cda3b2b..4aaf454 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -259,17 +259,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         }
 
         val aggregateOperator =
-          if (functionsWithDistinct.isEmpty) {
+          if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
+            if (functionsWithDistinct.nonEmpty) {
+              sys.error("Distinct columns cannot exist in Aggregate operator containing " +
+                "aggregate functions which don't support partial aggregation.")
+            } else {
+              aggregate.AggUtils.planAggregateWithoutPartial(
+                groupingExpressions,
+                aggregateExpressions,
+                resultExpressions,
+                planLater(child))
+            }
+          } else if (functionsWithDistinct.isEmpty) {
             aggregate.AggUtils.planAggregateWithoutDistinct(
               groupingExpressions,
               aggregateExpressions,
               resultExpressions,
               planLater(child))
           } else {
-            if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
-              sys.error("Distinct columns cannot exist in Aggregate operator containing " +
-                "aggregate functions which don't support partial aggregation.")
-            }
             aggregate.AggUtils.planAggregateWithOneDistinct(
               groupingExpressions,
               functionsWithDistinct,

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf632b2/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index fe75ece..4fbb9d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -19,97 +19,34 @@ package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.physical.Distribution
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec}
 
 /**
- * A pattern that finds aggregate operators to support partial aggregations.
- */
-object PartialAggregate {
-
-  def unapply(plan: SparkPlan): Option[Distribution] = plan match {
-    case agg: AggregateExec if AggUtils.supportPartialAggregate(agg.aggregateExpressions) =>
-      Some(agg.requiredChildDistribution.head)
-    case _ =>
-      None
-  }
-}
-
-/**
  * Utility functions used by the query planner to convert our plan to new aggregation code path.
  */
 object AggUtils {
 
-  def supportPartialAggregate(aggregateExpressions: Seq[AggregateExpression]): Boolean = {
-    aggregateExpressions.map(_.aggregateFunction).forall(_.supportsPartial)
-  }
-
-  private def createPartialAggregateExec(
+  def planAggregateWithoutPartial(
       groupingExpressions: Seq[NamedExpression],
       aggregateExpressions: Seq[AggregateExpression],
-      child: SparkPlan): SparkPlan = {
-    val groupingAttributes = groupingExpressions.map(_.toAttribute)
-    val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
-    val partialAggregateExpressions = aggregateExpressions.map {
-      case agg @ AggregateExpression(_, _, false, _) if functionsWithDistinct.length > 0 =>
-        agg.copy(mode = PartialMerge)
-      case agg =>
-        agg.copy(mode = Partial)
-    }
-    val partialAggregateAttributes =
-      partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
-    val partialResultExpressions =
-      groupingAttributes ++
-        partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
+      resultExpressions: Seq[NamedExpression],
+      child: SparkPlan): Seq[SparkPlan] = {
 
-    createAggregateExec(
-      requiredChildDistributionExpressions = None,
+    val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
+    val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute)
+    SortAggregateExec(
+      requiredChildDistributionExpressions = Some(groupingExpressions),
       groupingExpressions = groupingExpressions,
-      aggregateExpressions = partialAggregateExpressions,
-      aggregateAttributes = partialAggregateAttributes,
-      initialInputBufferOffset = if (functionsWithDistinct.length > 0) {
-        groupingExpressions.length + functionsWithDistinct.head.aggregateFunction.children.length
-      } else {
-        0
-      },
-      resultExpressions = partialResultExpressions,
-      child = child)
-  }
-
-  private def updateMergeAggregateMode(aggregateExpressions: Seq[AggregateExpression]) = {
-    def updateMode(mode: AggregateMode) = mode match {
-      case Partial => PartialMerge
-      case Complete => Final
-      case mode => mode
-    }
-    aggregateExpressions.map(e => e.copy(mode = updateMode(e.mode)))
-  }
-
-  /**
-   * Builds new merge and map-side [[AggregateExec]]s from an input aggregate operator.
-   * If an aggregation needs a shuffle for satisfying its own distribution and supports partial
-   * aggregations, a map-side aggregation is appended before the shuffle in
-   * [[org.apache.spark.sql.execution.exchange.EnsureRequirements]].
-   */
-  def createMapMergeAggregatePair(operator: SparkPlan): (SparkPlan, SparkPlan) = operator match {
-    case agg: AggregateExec =>
-      val mapSideAgg = createPartialAggregateExec(
-        agg.groupingExpressions, agg.aggregateExpressions, agg.child)
-      val mergeAgg = createAggregateExec(
-        requiredChildDistributionExpressions = agg.requiredChildDistributionExpressions,
-        groupingExpressions = agg.groupingExpressions.map(_.toAttribute),
-        aggregateExpressions = updateMergeAggregateMode(agg.aggregateExpressions),
-        aggregateAttributes = agg.aggregateAttributes,
-        initialInputBufferOffset = agg.groupingExpressions.length,
-        resultExpressions = agg.resultExpressions,
-        child = mapSideAgg
-      )
-
-      (mergeAgg, mapSideAgg)
+      aggregateExpressions = completeAggregateExpressions,
+      aggregateAttributes = completeAggregateAttributes,
+      initialInputBufferOffset = 0,
+      resultExpressions = resultExpressions,
+      child = child
+    ) :: Nil
   }
 
-  private def createAggregateExec(
+  private def createAggregate(
       requiredChildDistributionExpressions: Option[Seq[Expression]] = None,
       groupingExpressions: Seq[NamedExpression] = Nil,
       aggregateExpressions: Seq[AggregateExpression] = Nil,
@@ -118,8 +55,7 @@ object AggUtils {
       resultExpressions: Seq[NamedExpression] = Nil,
       child: SparkPlan): SparkPlan = {
     val useHash = HashAggregateExec.supportsAggregate(
-      aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) &&
-      supportPartialAggregate(aggregateExpressions)
+      aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
     if (useHash) {
       HashAggregateExec(
         requiredChildDistributionExpressions = requiredChildDistributionExpressions,
@@ -146,21 +82,43 @@ object AggUtils {
       aggregateExpressions: Seq[AggregateExpression],
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {
+    // Check if we can use HashAggregate.
+
+    // 1. Create an Aggregate Operator for partial aggregations.
+
     val groupingAttributes = groupingExpressions.map(_.toAttribute)
-    val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
-    val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute)
-    val supportPartial = supportPartialAggregate(aggregateExpressions)
+    val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial))
+    val partialAggregateAttributes =
+      partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+    val partialResultExpressions =
+      groupingAttributes ++
+        partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
 
-    createAggregateExec(
-      requiredChildDistributionExpressions =
-        Some(if (supportPartial) groupingAttributes else groupingExpressions),
-      groupingExpressions = groupingExpressions,
-      aggregateExpressions = completeAggregateExpressions,
-      aggregateAttributes = completeAggregateAttributes,
-      initialInputBufferOffset = 0,
-      resultExpressions = resultExpressions,
-      child = child
-    ) :: Nil
+    val partialAggregate = createAggregate(
+        requiredChildDistributionExpressions = None,
+        groupingExpressions = groupingExpressions,
+        aggregateExpressions = partialAggregateExpressions,
+        aggregateAttributes = partialAggregateAttributes,
+        initialInputBufferOffset = 0,
+        resultExpressions = partialResultExpressions,
+        child = child)
+
+    // 2. Create an Aggregate Operator for final aggregations.
+    val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
+    // The attributes of the final aggregation buffer, which is presented as input to the result
+    // projection:
+    val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
+
+    val finalAggregate = createAggregate(
+        requiredChildDistributionExpressions = Some(groupingAttributes),
+        groupingExpressions = groupingAttributes,
+        aggregateExpressions = finalAggregateExpressions,
+        aggregateAttributes = finalAggregateAttributes,
+        initialInputBufferOffset = groupingExpressions.length,
+        resultExpressions = resultExpressions,
+        child = partialAggregate)
+
+    finalAggregate :: Nil
   }
 
   def planAggregateWithOneDistinct(
@@ -183,23 +141,39 @@ object AggUtils {
     val distinctAttributes = namedDistinctExpressions.map(_.toAttribute)
     val groupingAttributes = groupingExpressions.map(_.toAttribute)
 
-    // 1. Create an Aggregate Operator for non-distinct aggregations.
+    // 1. Create an Aggregate Operator for partial aggregations.
     val partialAggregate: SparkPlan = {
       val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
       val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
-      createAggregateExec(
+      // We will group by the original grouping expression, plus an additional expression for the
+      // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
+      // expressions will be [key, value].
+      createAggregate(
+        groupingExpressions = groupingExpressions ++ namedDistinctExpressions,
+        aggregateExpressions = aggregateExpressions,
+        aggregateAttributes = aggregateAttributes,
+        resultExpressions = groupingAttributes ++ distinctAttributes ++
+          aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
+        child = child)
+    }
+
+    // 2. Create an Aggregate Operator for partial merge aggregations.
+    val partialMergeAggregate: SparkPlan = {
+      val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
+      val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
+      createAggregate(
         requiredChildDistributionExpressions =
           Some(groupingAttributes ++ distinctAttributes),
-        groupingExpressions = groupingExpressions ++ namedDistinctExpressions,
+        groupingExpressions = groupingAttributes ++ distinctAttributes,
         aggregateExpressions = aggregateExpressions,
         aggregateAttributes = aggregateAttributes,
         initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length,
         resultExpressions = groupingAttributes ++ distinctAttributes ++
           aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
-        child = child)
+        child = partialAggregate)
     }
 
-    // 2. Create an Aggregate Operator for the final aggregation.
+    // 3. Create an Aggregate operator for partial aggregation (for distinct)
     val distinctColumnAttributeLookup = distinctExpressions.zip(distinctAttributes).toMap
     val rewrittenDistinctFunctions = functionsWithDistinct.map {
       // Children of an AggregateFunction with DISTINCT keyword has already
@@ -209,6 +183,38 @@ object AggUtils {
         aggregateFunction.transformDown(distinctColumnAttributeLookup)
           .asInstanceOf[AggregateFunction]
     }
+
+    val partialDistinctAggregate: SparkPlan = {
+      val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
+      // The attributes of the final aggregation buffer, which is presented as input to the result
+      // projection:
+      val mergeAggregateAttributes = mergeAggregateExpressions.map(_.resultAttribute)
+      val (distinctAggregateExpressions, distinctAggregateAttributes) =
+        rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
+          // We rewrite the aggregate function to a non-distinct aggregation because
+          // its input will have distinct arguments.
+          // We just keep the isDistinct setting to true, so when users look at the query plan,
+          // they still can see distinct aggregations.
+          val expr = AggregateExpression(func, Partial, isDistinct = true)
+          // Use original AggregationFunction to lookup attributes, which is used to build
+          // aggregateFunctionToAttribute
+          val attr = functionsWithDistinct(i).resultAttribute
+          (expr, attr)
+      }.unzip
+
+      val partialAggregateResult = groupingAttributes ++
+          mergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) ++
+          distinctAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
+      createAggregate(
+        groupingExpressions = groupingAttributes,
+        aggregateExpressions = mergeAggregateExpressions ++ distinctAggregateExpressions,
+        aggregateAttributes = mergeAggregateAttributes ++ distinctAggregateAttributes,
+        initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length,
+        resultExpressions = partialAggregateResult,
+        child = partialMergeAggregate)
+    }
+
+    // 4. Create an Aggregate Operator for the final aggregation.
     val finalAndCompleteAggregate: SparkPlan = {
       val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
       // The attributes of the final aggregation buffer, which is presented as input to the result
@@ -219,23 +225,23 @@ object AggUtils {
         rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
           // We rewrite the aggregate function to a non-distinct aggregation because
           // its input will have distinct arguments.
-          // We keep the isDistinct setting to true because this flag is used to generate partial
-          // aggregations and it is easy to see aggregation types in the query plan.
-          val expr = AggregateExpression(func, Complete, isDistinct = true)
+          // We just keep the isDistinct setting to true, so when users look at the query plan,
+          // they still can see distinct aggregations.
+          val expr = AggregateExpression(func, Final, isDistinct = true)
           // Use original AggregationFunction to lookup attributes, which is used to build
           // aggregateFunctionToAttribute
           val attr = functionsWithDistinct(i).resultAttribute
           (expr, attr)
-        }.unzip
+      }.unzip
 
-      createAggregateExec(
+      createAggregate(
         requiredChildDistributionExpressions = Some(groupingAttributes),
         groupingExpressions = groupingAttributes,
         aggregateExpressions = finalAggregateExpressions ++ distinctAggregateExpressions,
         aggregateAttributes = finalAggregateAttributes ++ distinctAggregateAttributes,
         initialInputBufferOffset = groupingAttributes.length,
         resultExpressions = resultExpressions,
-        child = partialAggregate)
+        child = partialDistinctAggregate)
     }
 
     finalAndCompleteAggregate :: Nil
@@ -243,14 +249,13 @@ object AggUtils {
 
   /**
    * Plans a streaming aggregation using the following progression:
-   *  - Partial Aggregation (now there is at most 1 tuple per group)
+   *  - Partial Aggregation
+   *  - Shuffle
+   *  - Partial Merge (now there is at most 1 tuple per group)
    *  - StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous)
    *  - PartialMerge (now there is at most 1 tuple per group)
    *  - StateStoreSave (saves the tuple for the next batch)
    *  - Complete (output the current result of the aggregation)
-   *
-   *  If the first aggregation needs a shuffle to satisfy its distribution, a map-side partial
-   *  an aggregation and a shuffle are added in `EnsureRequirements`.
    */
   def planStreamingAggregation(
       groupingExpressions: Seq[NamedExpression],
@@ -263,24 +268,39 @@ object AggUtils {
     val partialAggregate: SparkPlan = {
       val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
       val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
-      createAggregateExec(
+      // We will group by the original grouping expression, plus an additional expression for the
+      // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
+      // expressions will be [key, value].
+      createAggregate(
+        groupingExpressions = groupingExpressions,
+        aggregateExpressions = aggregateExpressions,
+        aggregateAttributes = aggregateAttributes,
+        resultExpressions = groupingAttributes ++
+            aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
+        child = child)
+    }
+
+    val partialMerged1: SparkPlan = {
+      val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
+      val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
+      createAggregate(
         requiredChildDistributionExpressions =
             Some(groupingAttributes),
-        groupingExpressions = groupingExpressions,
+        groupingExpressions = groupingAttributes,
         aggregateExpressions = aggregateExpressions,
         aggregateAttributes = aggregateAttributes,
         initialInputBufferOffset = groupingAttributes.length,
         resultExpressions = groupingAttributes ++
             aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
-        child = child)
+        child = partialAggregate)
     }
 
-    val restored = StateStoreRestoreExec(groupingAttributes, None, partialAggregate)
+    val restored = StateStoreRestoreExec(groupingAttributes, None, partialMerged1)
 
-    val partialMerged: SparkPlan = {
+    val partialMerged2: SparkPlan = {
       val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
       val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
-      createAggregateExec(
+      createAggregate(
         requiredChildDistributionExpressions =
             Some(groupingAttributes),
         groupingExpressions = groupingAttributes,
@@ -294,7 +314,7 @@ object AggUtils {
     // Note: stateId and returnAllStates are filled in later with preparation rules
     // in IncrementalExecution.
     val saved = StateStoreSaveExec(
-      groupingAttributes, stateId = None, returnAllStates = None, partialMerged)
+      groupingAttributes, stateId = None, returnAllStates = None, partialMerged2)
 
     val finalAndCompleteAggregate: SparkPlan = {
       val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
@@ -302,7 +322,7 @@ object AggUtils {
       // projection:
       val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
 
-      createAggregateExec(
+      createAggregate(
         requiredChildDistributionExpressions = Some(groupingAttributes),
         groupingExpressions = groupingAttributes,
         aggregateExpressions = finalAggregateExpressions,

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf632b2/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala
deleted file mode 100644
index b88a8aa..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala
+++ /dev/null
@@ -1,56 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.aggregate
-
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.UnaryExecNode
-
-/**
- * A base class for aggregate implementation.
- */
-abstract class AggregateExec extends UnaryExecNode {
-
-  def requiredChildDistributionExpressions: Option[Seq[Expression]]
-  def groupingExpressions: Seq[NamedExpression]
-  def aggregateExpressions: Seq[AggregateExpression]
-  def aggregateAttributes: Seq[Attribute]
-  def initialInputBufferOffset: Int
-  def resultExpressions: Seq[NamedExpression]
-
-  protected[this] val aggregateBufferAttributes = {
-    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
-  }
-
-  override def producedAttributes: AttributeSet =
-    AttributeSet(aggregateAttributes) ++
-      AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
-      AttributeSet(aggregateBufferAttributes)
-
-  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
-  override def requiredChildDistribution: List[Distribution] = {
-    requiredChildDistributionExpressions match {
-      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
-      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
-      case None => UnspecifiedDistribution :: Nil
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf632b2/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 525c7e3..bd7efa6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
@@ -41,7 +42,11 @@ case class HashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends AggregateExec with CodegenSupport {
+  extends UnaryExecNode with CodegenSupport {
+
+  private[this] val aggregateBufferAttributes = {
+    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+  }
 
   require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))
 
@@ -55,6 +60,21 @@ case class HashAggregateExec(
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
     "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"))
 
+  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  override def producedAttributes: AttributeSet =
+    AttributeSet(aggregateAttributes) ++
+    AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
+    AttributeSet(aggregateBufferAttributes)
+
+  override def requiredChildDistribution: List[Distribution] = {
+    requiredChildDistributionExpressions match {
+      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
+      case None => UnspecifiedDistribution :: Nil
+    }
+  }
+
   // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash
   // map and/or the sort-based aggregation once it has processed a given number of input rows.
   private val testFallbackStartsAt: Option[(Int, Int)] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf632b2/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index 68f86fc..2a81a82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.util.Utils
 
@@ -37,11 +38,30 @@ case class SortAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends AggregateExec {
+  extends UnaryExecNode {
+
+  private[this] val aggregateBufferAttributes = {
+    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+  }
+
+  override def producedAttributes: AttributeSet =
+    AttributeSet(aggregateAttributes) ++
+      AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
+      AttributeSet(aggregateBufferAttributes)
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
+  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  override def requiredChildDistribution: List[Distribution] = {
+    requiredChildDistributionExpressions match {
+      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
+      case None => UnspecifiedDistribution :: Nil
+    }
+  }
+
   override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
     groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf632b2/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 66e99de..f170499 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -21,8 +21,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.aggregate.AggUtils
-import org.apache.spark.sql.execution.aggregate.PartialAggregate
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -153,31 +151,18 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
   private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
     val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
     val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
-    assert(requiredChildDistributions.length == operator.children.length)
-    assert(requiredChildOrderings.length == operator.children.length)
+    var children: Seq[SparkPlan] = operator.children
+    assert(requiredChildDistributions.length == children.length)
+    assert(requiredChildOrderings.length == children.length)
 
-    def createShuffleExchange(dist: Distribution, child: SparkPlan) =
-      ShuffleExchange(createPartitioning(dist, defaultNumPreShufflePartitions), child)
-
-    var (parent, children) = operator match {
-      case PartialAggregate(childDist) if !operator.outputPartitioning.satisfies(childDist) =>
-        // If an aggregation needs a shuffle and support partial aggregations, a map-side partial
-        // aggregation and a shuffle are added as children.
-        val (mergeAgg, mapSideAgg) = AggUtils.createMapMergeAggregatePair(operator)
-        (mergeAgg, createShuffleExchange(
-          requiredChildDistributions.head, ensureDistributionAndOrdering(mapSideAgg)) :: Nil)
-      case _ =>
-        // Ensure that the operator's children satisfy their output distribution requirements:
-        val childrenWithDist = operator.children.zip(requiredChildDistributions)
-        val newChildren = childrenWithDist.map {
-          case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
-            child
-          case (child, BroadcastDistribution(mode)) =>
-            BroadcastExchangeExec(mode, child)
-          case (child, distribution) =>
-            createShuffleExchange(distribution, child)
-        }
-        (operator, newChildren)
+    // Ensure that the operator's children satisfy their output distribution requirements:
+    children = children.zip(requiredChildDistributions).map {
+      case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
+        child
+      case (child, BroadcastDistribution(mode)) =>
+        BroadcastExchangeExec(mode, child)
+      case (child, distribution) =>
+        ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
     }
 
     // If the operator has multiple children and specifies child output distributions (e.g. join),
@@ -270,7 +255,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
       }
     }
 
-    parent.withNewChildren(children)
+    operator.withNewChildren(children)
   }
 
   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf632b2/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index ce0b92a..f899517 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1248,17 +1248,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
   }
 
   /**
-   * Verifies that there is a single Aggregation for `df`
+   * Verifies that there is no Exchange between the Aggregations for `df`
    */
-  private def verifyNonExchangingSingleAgg(df: DataFrame) = {
+  private def verifyNonExchangingAgg(df: DataFrame) = {
     var atFirstAgg: Boolean = false
     df.queryExecution.executedPlan.foreach {
       case agg: HashAggregateExec =>
+        atFirstAgg = !atFirstAgg
+      case _ =>
         if (atFirstAgg) {
-          fail("Should not have back to back Aggregates")
+          fail("Should not have operators between the two aggregations")
         }
-        atFirstAgg = true
-      case _ =>
     }
   }
 
@@ -1292,10 +1292,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     // Group by the column we are distributed by. This should generate a plan with no exchange
     // between the aggregates
     val df3 = testData.repartition($"key").groupBy("key").count()
-    verifyNonExchangingSingleAgg(df3)
-    verifyNonExchangingSingleAgg(testData.repartition($"key", $"value")
+    verifyNonExchangingAgg(df3)
+    verifyNonExchangingAgg(testData.repartition($"key", $"value")
       .groupBy("key", "value").count())
-    verifyNonExchangingSingleAgg(testData.repartition($"key").groupBy("key", "value").count())
 
     // Grouping by just the first distributeBy expr, need to exchange.
     verifyExchangingAgg(testData.repartition($"key", $"value")

http://git-wip-us.apache.org/repos/asf/spark/blob/aaf632b2/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index b0aa337..375da22 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.aggregate.SortAggregateExec
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchange}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
@@ -38,84 +37,36 @@ class PlannerSuite extends SharedSQLContext {
 
   setupTestData()
 
-  private def testPartialAggregationPlan(query: LogicalPlan): Seq[SparkPlan] = {
+  private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
     val planner = spark.sessionState.planner
     import planner._
-    val ensureRequirements = EnsureRequirements(spark.sessionState.conf)
-    val planned = Aggregation(query).headOption.map(ensureRequirements(_))
-      .getOrElse(fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
-    planned.collect { case n if n.nodeName contains "Aggregate" => n }
+    val plannedOption = Aggregation(query).headOption
+    val planned =
+      plannedOption.getOrElse(
+        fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
+    val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
+
+    // For the new aggregation code path, there will be four aggregate operator for
+    // distinct aggregations.
+    assert(
+      aggregations.size == 2 || aggregations.size == 4,
+      s"The plan of query $query does not have partial aggregations.")
   }
 
   test("count is partially aggregated") {
     val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
-    assert(testPartialAggregationPlan(query).size == 2,
-      s"The plan of query $query does not have partial aggregations.")
+    testPartialAggregationPlan(query)
   }
 
   test("count distinct is partially aggregated") {
     val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
     testPartialAggregationPlan(query)
-    // For the new aggregation code path, there will be four aggregate operator for  distinct
-    // aggregations.
-    assert(testPartialAggregationPlan(query).size == 4,
-      s"The plan of query $query does not have partial aggregations.")
   }
 
   test("mixed aggregates are partially aggregated") {
     val query =
       testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
-    // For the new aggregation code path, there will be four aggregate operator for  distinct
-    // aggregations.
-    assert(testPartialAggregationPlan(query).size == 4,
-      s"The plan of query $query does not have partial aggregations.")
-  }
-
-  test("SPARK-17289 sort-based partial aggregation needs a sort operator as a child") {
-    withTempView("testSortBasedPartialAggregation") {
-      val schema = StructType(
-        StructField(s"key", IntegerType, true) :: StructField(s"value", StringType, true) :: Nil)
-      val rowRDD = sparkContext.parallelize((0 until 1000).map(d => Row(d % 2, d.toString)))
-      spark.createDataFrame(rowRDD, schema)
-        .createOrReplaceTempView("testSortBasedPartialAggregation")
-
-      // This test assumes a query below uses sort-based aggregations
-      val planned = sql("SELECT MAX(value) FROM testSortBasedPartialAggregation GROUP BY key")
-        .queryExecution.executedPlan
-      // This line extracts both SortAggregate and Sort operators
-      val extractedOps = planned.collect { case n if n.nodeName contains "Sort" => n }
-      val aggOps = extractedOps.collect { case n if n.nodeName contains "SortAggregate" => n }
-      assert(extractedOps.size == 4 && aggOps.size == 2,
-        s"The plan $planned does not have correct sort-based partial aggregate pairs.")
-    }
-  }
-
-  test("non-partial aggregation for aggregates") {
-    withTempView("testNonPartialAggregation") {
-      val schema = StructType(StructField(s"value", IntegerType, true) :: Nil)
-      val row = Row.fromSeq(Seq.fill(1)(null))
-      val rowRDD = sparkContext.parallelize(row :: Nil)
-      spark.createDataFrame(rowRDD, schema).repartition($"value")
-        .createOrReplaceTempView("testNonPartialAggregation")
-
-      val planned1 = sql("SELECT SUM(value) FROM testNonPartialAggregation GROUP BY value")
-        .queryExecution.executedPlan
-
-      // If input data are already partitioned and the same columns are used in grouping keys and
-      // aggregation values, no partial aggregation exist in query plans.
-      val aggOps1 = planned1.collect { case n if n.nodeName contains "Aggregate" => n }
-      assert(aggOps1.size == 1, s"The plan $planned1 has partial aggregations.")
-
-      val planned2 = sql(
-        """
-          |SELECT t.value, SUM(DISTINCT t.value)
-          |FROM (SELECT * FROM testNonPartialAggregation ORDER BY value) t
-          |GROUP BY t.value
-        """.stripMargin).queryExecution.executedPlan
-
-      val aggOps2 = planned1.collect { case n if n.nodeName contains "Aggregate" => n }
-      assert(aggOps2.size == 1, s"The plan $planned2 has partial aggregations.")
-    }
+    testPartialAggregationPlan(query)
   }
 
   test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org