You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yh...@apache.org on 2015/07/30 19:32:22 UTC
spark git commit: [SPARK-9361] [SQL] Refactor new aggregation code to
reduce the times of checking compatibility
Repository: spark
Updated Branches:
refs/heads/master 7bbf02f0b -> 5363ed715
[SPARK-9361] [SQL] Refactor new aggregation code to reduce the times of checking compatibility
JIRA: https://issues.apache.org/jira/browse/SPARK-9361
Currently, we call `aggregate.Utils.tryConvert` in many places to check if the logical.Aggregate can be run with the new aggregation. However, `aggregate.Utils.tryConvert` appears to take considerable time to run. We should call `tryConvert` only once, keep its value in `logical.Aggregate`, and reuse it.
In `org.apache.spark.sql.execution.aggregate.Utils`, the code involving `tryConvert` should be moved to catalyst because it does not actually deal with execution details.
Author: Liang-Chi Hsieh <vi...@appier.com>
Closes #7677 from viirya/refactor_aggregate and squashes the following commits:
babea30 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into refactor_aggregate
9a589d7 [Liang-Chi Hsieh] Fix scala style.
0a91329 [Liang-Chi Hsieh] Refactor new aggregation code to reduce the times to call tryConvert.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5363ed71
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5363ed71
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5363ed71
Branch: refs/heads/master
Commit: 5363ed71568c3e7c082146d654a9c669d692d894
Parents: 7bbf02f
Author: Liang-Chi Hsieh <vi...@appier.com>
Authored: Thu Jul 30 10:30:37 2015 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Thu Jul 30 10:32:12 2015 -0700
----------------------------------------------------------------------
.../expressions/aggregate/interfaces.scala | 4 +-
.../catalyst/expressions/aggregate/utils.scala | 167 +++++++++++++++++++
.../catalyst/plans/logical/basicOperators.scala | 3 +
.../spark/sql/execution/SparkStrategies.scala | 34 ++--
.../spark/sql/execution/aggregate/utils.scala | 144 ----------------
5 files changed, 188 insertions(+), 164 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/5363ed71/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 9fb7623..d08f553 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -42,7 +42,7 @@ private[sql] case object Partial extends AggregateMode
private[sql] case object PartialMerge extends AggregateMode
/**
- * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers
+ * An [[AggregateFunction2]] with [[Final]] mode is used to merge aggregation buffers
* containing intermediate results for this function and then generate final result.
* This function updates the given aggregation buffer by merging multiple aggregation buffers.
* When it has processed all input rows, the final result of this function is returned.
@@ -50,7 +50,7 @@ private[sql] case object PartialMerge extends AggregateMode
private[sql] case object Final extends AggregateMode
/**
- * An [[AggregateFunction2]] with [[Partial]] mode is used to evaluate this function directly
+ * An [[AggregateFunction2]] with [[Complete]] mode is used to evaluate this function directly
* from original input rows without any partial aggregation.
* This function updates the given aggregation buffer with the original input of this
* function. When it has processed all input rows, the final result of this function is returned.
http://git-wip-us.apache.org/repos/asf/spark/blob/5363ed71/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
new file mode 100644
index 0000000..4a43318
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
+
+/**
+ * Utility functions used by the query planner to convert our plan to new aggregation code path.
+ */
+object Utils {
+ // Right now, we do not support complex types in the grouping key schema.
+ private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = {
+ val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists {
+ case array: ArrayType => true
+ case map: MapType => true
+ case struct: StructType => true
+ case _ => false
+ }
+
+ !hasComplexTypes
+ }
+
+ private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
+ case p: Aggregate if supportsGroupingKeySchema(p) =>
+ val converted = p.transformExpressionsDown {
+ case expressions.Average(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Average(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.Count(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Count(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ // We do not support multiple COUNT DISTINCT columns for now.
+ case expressions.CountDistinct(children) if children.length == 1 =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Count(children.head),
+ mode = aggregate.Complete,
+ isDistinct = true)
+
+ case expressions.First(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.First(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.Last(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Last(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.Max(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Max(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.Min(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Min(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.Sum(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Sum(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.SumDistinct(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Sum(child),
+ mode = aggregate.Complete,
+ isDistinct = true)
+ }
+ // Check if there is any expressions.AggregateExpression1 left.
+ // If so, we cannot convert this plan.
+ val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr =>
+ // For every expressions, check if it contains AggregateExpression1.
+ expr.find {
+ case agg: expressions.AggregateExpression1 => true
+ case other => false
+ }.isDefined
+ }
+
+ // Check if there are multiple distinct columns.
+ val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
+ expr.collect {
+ case agg: AggregateExpression2 => agg
+ }
+ }.toSet.toSeq
+ val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
+ val hasMultipleDistinctColumnSets =
+ if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
+ true
+ } else {
+ false
+ }
+
+ if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None
+
+ case other => None
+ }
+
+ def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = {
+ // If the plan cannot be converted, we will do a final round check to see if the original
+ // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so,
+ // we need to throw an exception.
+ val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr =>
+ expr.collect {
+ case agg: AggregateExpression2 => agg.aggregateFunction
+ }
+ }.distinct
+ if (aggregateFunction2s.nonEmpty) {
+ // For functions implemented based on the new interface, prepare a list of function names.
+ val invalidFunctions = {
+ if (aggregateFunction2s.length > 1) {
+ s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " +
+ s"and ${aggregateFunction2s.head.nodeName} are"
+ } else {
+ s"${aggregateFunction2s.head.nodeName} is"
+ }
+ }
+ val errorMessage =
+ s"${invalidFunctions} implemented based on the new Aggregate Function " +
+ s"interface and it cannot be used with functions implemented based on " +
+ s"the old Aggregate Function interface."
+ throw new AnalysisException(errorMessage)
+ }
+ }
+
+ def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
+ case p: Aggregate =>
+ val converted = doConvert(p)
+ if (converted.isDefined) {
+ converted
+ } else {
+ checkInvalidAggregateFunction2(p)
+ None
+ }
+ case other => None
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/5363ed71/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index ad5af19..a67f8de 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.Utils
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet
@@ -219,6 +220,8 @@ case class Aggregate(
expressions.forall(_.resolved) && childrenResolved && !hasWindowExpressions
}
+ lazy val newAggregation: Option[Aggregate] = Utils.tryConvert(this)
+
override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/5363ed71/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f3ef066..52a9b02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, Utils}
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
@@ -193,11 +193,15 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case _ => Nil
}
- def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = {
- aggregate.Utils.tryConvert(
- plan,
- sqlContext.conf.useSqlAggregate2,
- sqlContext.conf.codegenEnabled).isDefined
+ def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = plan match {
+ case a: logical.Aggregate =>
+ if (sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled) {
+ a.newAggregation.isDefined
+ } else {
+ Utils.checkInvalidAggregateFunction2(a)
+ false
+ }
+ case _ => false
}
def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = aggs.forall {
@@ -217,12 +221,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
*/
object Aggregation extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case p: logical.Aggregate =>
- val converted =
- aggregate.Utils.tryConvert(
- p,
- sqlContext.conf.useSqlAggregate2,
- sqlContext.conf.codegenEnabled)
+ case p: logical.Aggregate if sqlContext.conf.useSqlAggregate2 &&
+ sqlContext.conf.codegenEnabled =>
+ val converted = p.newAggregation
converted match {
case None => Nil // Cannot convert to new aggregation code path.
case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) =>
@@ -377,17 +378,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case e @ logical.Expand(_, _, _, child) =>
execution.Expand(e.projections, e.output, planLater(child)) :: Nil
case a @ logical.Aggregate(group, agg, child) => {
- val useNewAggregation =
- aggregate.Utils.tryConvert(
- a,
- sqlContext.conf.useSqlAggregate2,
- sqlContext.conf.codegenEnabled).isDefined
- if (useNewAggregation) {
+ val useNewAggregation = sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled
+ if (useNewAggregation && a.newAggregation.isDefined) {
// If this logical.Aggregate can be planned to use new aggregation code path
// (i.e. it can be planned by the Strategy Aggregation), we will not use the old
// aggregation code path.
Nil
} else {
+ Utils.checkInvalidAggregateFunction2(a)
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/5363ed71/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 6549c87..03635ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -29,150 +29,6 @@ import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
* Utility functions used by the query planner to convert our plan to new aggregation code path.
*/
object Utils {
- // Right now, we do not support complex types in the grouping key schema.
- private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = {
- val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists {
- case array: ArrayType => true
- case map: MapType => true
- case struct: StructType => true
- case _ => false
- }
-
- !hasComplexTypes
- }
-
- private def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
- case p: Aggregate if supportsGroupingKeySchema(p) =>
- val converted = p.transformExpressionsDown {
- case expressions.Average(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Average(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Count(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Count(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- // We do not support multiple COUNT DISTINCT columns for now.
- case expressions.CountDistinct(children) if children.length == 1 =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Count(children.head),
- mode = aggregate.Complete,
- isDistinct = true)
-
- case expressions.First(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.First(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Last(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Last(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Max(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Max(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Min(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Min(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Sum(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Sum(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.SumDistinct(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Sum(child),
- mode = aggregate.Complete,
- isDistinct = true)
- }
- // Check if there is any expressions.AggregateExpression1 left.
- // If so, we cannot convert this plan.
- val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr =>
- // For every expressions, check if it contains AggregateExpression1.
- expr.find {
- case agg: expressions.AggregateExpression1 => true
- case other => false
- }.isDefined
- }
-
- // Check if there are multiple distinct columns.
- val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
- expr.collect {
- case agg: AggregateExpression2 => agg
- }
- }.toSet.toSeq
- val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
- val hasMultipleDistinctColumnSets =
- if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
- true
- } else {
- false
- }
-
- if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None
-
- case other => None
- }
-
- private def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = {
- // If the plan cannot be converted, we will do a final round check to if the original
- // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so,
- // we need to throw an exception.
- val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr =>
- expr.collect {
- case agg: AggregateExpression2 => agg.aggregateFunction
- }
- }.distinct
- if (aggregateFunction2s.nonEmpty) {
- // For functions implemented based on the new interface, prepare a list of function names.
- val invalidFunctions = {
- if (aggregateFunction2s.length > 1) {
- s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " +
- s"and ${aggregateFunction2s.head.nodeName} are"
- } else {
- s"${aggregateFunction2s.head.nodeName} is"
- }
- }
- val errorMessage =
- s"${invalidFunctions} implemented based on the new Aggregate Function " +
- s"interface and it cannot be used with functions implemented based on " +
- s"the old Aggregate Function interface."
- throw new AnalysisException(errorMessage)
- }
- }
-
- def tryConvert(
- plan: LogicalPlan,
- useNewAggregation: Boolean,
- codeGenEnabled: Boolean): Option[Aggregate] = plan match {
- case p: Aggregate if useNewAggregation && codeGenEnabled =>
- val converted = tryConvert(p)
- if (converted.isDefined) {
- converted
- } else {
- checkInvalidAggregateFunction2(p)
- None
- }
- case p: Aggregate =>
- checkInvalidAggregateFunction2(p)
- None
- case other => None
- }
-
def planAggregateWithoutDistinct(
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[AggregateExpression2],
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org