Posted to commits@spark.apache.org by rx...@apache.org on 2016/08/27 05:10:32 UTC
spark git commit: [SPARK-17269][SQL] Move finish analysis optimization stage into its own file
Repository: spark
Updated Branches:
refs/heads/master cc0caa690 -> dcefac438
[SPARK-17269][SQL] Move finish analysis optimization stage into its own file
## What changes were proposed in this pull request?
As part of breaking Optimizer.scala apart, this patch moves the rules of the finish analysis optimization stage into a single file. I'm submitting separate pull requests so we can more easily merge this into branch-2.0 and simplify optimizer backports.
## How was this patch tested?
This should be covered by existing tests.
Author: Reynold Xin <rx...@databricks.com>
Closes #14838 from rxin/SPARK-17269.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/dcefac43
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/dcefac43
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/dcefac43
Branch: refs/heads/master
Commit: dcefac438788c51d84641bfbc505efe095731a39
Parents: cc0caa6
Author: Reynold Xin <rx...@databricks.com>
Authored: Fri Aug 26 22:10:28 2016 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Fri Aug 26 22:10:28 2016 -0700
----------------------------------------------------------------------
.../analysis/RewriteDistinctAggregates.scala | 269 -------------------
.../sql/catalyst/optimizer/Optimizer.scala | 38 ---
.../optimizer/RewriteDistinctAggregates.scala | 269 +++++++++++++++++++
.../sql/catalyst/optimizer/finishAnalysis.scala | 65 +++++
4 files changed, 334 insertions(+), 307 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/dcefac43/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDistinctAggregates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDistinctAggregates.scala
deleted file mode 100644
index 8afd28d..0000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDistinctAggregates.scala
+++ /dev/null
@@ -1,269 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.analysis
-
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan}
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.types.IntegerType
-
-/**
- * This rule rewrites an aggregate query with distinct aggregations into an expanded double
- * aggregation in which the regular aggregation expressions and every distinct clause are each
- * aggregated in a separate group. The results are then combined in a second aggregate.
- *
- * For example (in scala):
- * {{{
- * val data = Seq(
- * ("a", "ca1", "cb1", 10),
- * ("a", "ca1", "cb2", 5),
- * ("b", "ca1", "cb1", 13))
- * .toDF("key", "cat1", "cat2", "value")
- * data.createOrReplaceTempView("data")
- *
- * val agg = data.groupBy($"key")
- * .agg(
- * countDistinct($"cat1").as("cat1_cnt"),
- * countDistinct($"cat2").as("cat2_cnt"),
- * sum($"value").as("total"))
- * }}}
- *
- * This translates to the following (pseudo) logical plan:
- * {{{
- * Aggregate(
- * key = ['key]
- * functions = [COUNT(DISTINCT 'cat1),
- * COUNT(DISTINCT 'cat2),
- * sum('value)]
- * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
- * LocalTableScan [...]
- * }}}
- *
- * This rule rewrites this logical plan to the following (pseudo) logical plan:
- * {{{
- * Aggregate(
- * key = ['key]
- * functions = [count(if (('gid = 1)) 'cat1 else null),
- * count(if (('gid = 2)) 'cat2 else null),
- * first(if (('gid = 0)) 'total else null) ignore nulls]
- * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
- * Aggregate(
- * key = ['key, 'cat1, 'cat2, 'gid]
- * functions = [sum('value)]
- * output = ['key, 'cat1, 'cat2, 'gid, 'total])
- * Expand(
- * projections = [('key, null, null, 0, cast('value as bigint)),
- * ('key, 'cat1, null, 1, null),
- * ('key, null, 'cat2, 2, null)]
- * output = ['key, 'cat1, 'cat2, 'gid, 'value])
- * LocalTableScan [...]
- * }}}
- *
- * The rule does the following things here:
- * 1. Expand the data. There are three aggregation groups in this query:
- * i. the non-distinct group;
- * ii. the distinct 'cat1 group;
- * iii. the distinct 'cat2 group.
- * An expand operator is inserted to expand the child data for each group. The expand will null
- * out all unused columns for the given group; this must be done in order to ensure correctness
- * later on. Groups can be identified by a group id (gid) column added by the expand operator.
- * 2. De-duplicate the distinct paths and aggregate the non-distinct path. The group by clause of
- * this aggregate consists of the original group by clause, all the requested distinct columns
- * and the group id. Both the de-duplication of the distinct columns and the aggregation of the
- * non-distinct group take advantage of the fact that we group by the group id (gid) and that we
- * have nulled out all non-relevant columns for the given group.
- * 3. Aggregate the distinct groups and combine this with the results of the non-distinct
- * aggregation. In this step we use the group id to filter the inputs for the aggregate
- * functions. The results of the non-distinct group are 'aggregated' by using the first operator;
- * it might be more elegant to use the native UDAF merge mechanism for this in the future.
- *
- * This rule duplicates the input data two or more times (# distinct groups + an optional
- * non-distinct group). This puts quite a bit of memory pressure on the aggregate and exchange
- * operators involved. Keeping the number of distinct groups as low as possible should be a
- * priority; we could improve this in the current rule by applying more advanced expression
- * canonicalization techniques.
- */
-object RewriteDistinctAggregates extends Rule[LogicalPlan] {
-
- def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
- case a: Aggregate => rewrite(a)
- }
-
- def rewrite(a: Aggregate): Aggregate = {
-
- // Collect all aggregate expressions.
- val aggExpressions = a.aggregateExpressions.flatMap { e =>
- e.collect {
- case ae: AggregateExpression => ae
- }
- }
-
- // Extract distinct aggregate expressions.
- val distinctAggGroups = aggExpressions
- .filter(_.isDistinct)
- .groupBy(_.aggregateFunction.children.toSet)
-
- // The aggregation strategy can handle a query with a single distinct group.
- if (distinctAggGroups.size > 1) {
- // Create the attributes for the grouping id and the group by clause.
- val gid =
- new AttributeReference("gid", IntegerType, false)(isGenerated = true)
- val groupByMap = a.groupingExpressions.collect {
- case ne: NamedExpression => ne -> ne.toAttribute
- case e => e -> new AttributeReference(e.sql, e.dataType, e.nullable)()
- }
- val groupByAttrs = groupByMap.map(_._2)
-
- // Functions used to modify aggregate functions and their inputs.
- def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
- def patchAggregateFunctionChildren(
- af: AggregateFunction)(
- attrs: Expression => Expression): AggregateFunction = {
- af.withNewChildren(af.children.map {
- case afc => attrs(afc)
- }).asInstanceOf[AggregateFunction]
- }
-
- // Setup unique distinct aggregate children.
- val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
- val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
- val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
-
- // Setup expand & aggregate operators for distinct aggregate expressions.
- val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap
- val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
- case ((group, expressions), i) =>
- val id = Literal(i + 1)
-
- // Expand projection
- val projection = distinctAggChildren.map {
- case e if group.contains(e) => e
- case e => nullify(e)
- } :+ id
-
- // Final aggregate
- val operators = expressions.map { e =>
- val af = e.aggregateFunction
- val naf = patchAggregateFunctionChildren(af) { x =>
- evalWithinGroup(id, distinctAggChildAttrLookup(x))
- }
- (e, e.copy(aggregateFunction = naf, isDistinct = false))
- }
-
- (projection, operators)
- }
-
- // Setup expand for the 'regular' aggregate expressions.
- val regularAggExprs = aggExpressions.filter(!_.isDistinct)
- val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
- val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)
-
- // Setup aggregates for 'regular' aggregate expressions.
- val regularGroupId = Literal(0)
- val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
- val regularAggOperatorMap = regularAggExprs.map { e =>
- // Perform the actual aggregation in the initial aggregate.
- val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
- val operator = Alias(e.copy(aggregateFunction = af), e.sql)()
-
- // Select the result of the first aggregate in the last aggregate.
- val result = AggregateExpression(
- aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)),
- mode = Complete,
- isDistinct = false)
-
- // Some aggregate functions (COUNT) have the special property that they can return a
- // non-null result without any input. We need to make sure we return a result in this case.
- val resultWithDefault = af.defaultResult match {
- case Some(lit) => Coalesce(Seq(result, lit))
- case None => result
- }
-
- // Return a Tuple3 containing:
- // i. The original aggregate expression (used for look ups).
- // ii. The actual aggregation operator (used in the first aggregate).
- // iii. The operator that selects and returns the result (used in the second aggregate).
- (e, operator, resultWithDefault)
- }
-
- // Construct the regular aggregate input projection only if we need one.
- val regularAggProjection = if (regularAggExprs.nonEmpty) {
- Seq(a.groupingExpressions ++
- distinctAggChildren.map(nullify) ++
- Seq(regularGroupId) ++
- regularAggChildren)
- } else {
- Seq.empty[Seq[Expression]]
- }
-
- // Construct the distinct aggregate input projections.
- val regularAggNulls = regularAggChildren.map(nullify)
- val distinctAggProjections = distinctAggOperatorMap.map {
- case (projection, _) =>
- a.groupingExpressions ++
- projection ++
- regularAggNulls
- }
-
- // Construct the expand operator.
- val expand = Expand(
- regularAggProjection ++ distinctAggProjections,
- groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2),
- a.child)
-
- // Construct the first aggregate operator. This de-duplicates all the children of the
- // distinct operators, and applies the regular aggregate operators.
- val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid
- val firstAggregate = Aggregate(
- firstAggregateGroupBy,
- firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2),
- expand)
-
- // Construct the second aggregate
- val transformations: Map[Expression, Expression] =
- (distinctAggOperatorMap.flatMap(_._2) ++
- regularAggOperatorMap.map(e => (e._1, e._3))).toMap
-
- val patchedAggExpressions = a.aggregateExpressions.map { e =>
- e.transformDown {
- case e: Expression =>
- // The same GROUP BY clauses can have different forms (different names for instance) in
- // the groupBy and aggregate expressions of an aggregate. This makes a map lookup
- // tricky. So we do a linear search for a semantically equal group by expression.
- groupByMap
- .find(ge => e.semanticEquals(ge._1))
- .map(_._2)
- .getOrElse(transformations.getOrElse(e, e))
- }.asInstanceOf[NamedExpression]
- }
- Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate)
- } else {
- a
- }
- }
-
- private def nullify(e: Expression) = Literal.create(null, e.dataType)
-
- private def expressionAttributePair(e: Expression) =
- // We are creating a new reference here instead of reusing the attribute in the case of a
- // NamedExpression. This is done to prevent collisions between distinct and regular aggregate
- // children; in that case attribute reuse causes the input of the regular aggregate to be
- // bound to the (nulled out) input of the distinct aggregate.
- e -> new AttributeReference(e.sql, e.dataType, true)()
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/dcefac43/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 7bbcd74..d055bc3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1639,44 +1639,6 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
}
/**
- * Finds all [[RuntimeReplaceable]] expressions and replaces them with expressions that can
- * be evaluated. This is mainly used to provide compatibility with other databases.
- * For example, we use this to support "nvl" by replacing it with "coalesce".
- */
-object ReplaceExpressions extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case e: RuntimeReplaceable => e.replaced
- }
-}
-
-/**
- * Computes the current date and time once so that all references to them within a single query return the same result.
- */
-object ComputeCurrentTime extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = {
- val dateExpr = CurrentDate()
- val timeExpr = CurrentTimestamp()
- val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType)
- val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType)
-
- plan transformAllExpressions {
- case CurrentDate() => currentDate
- case CurrentTimestamp() => currentTime
- }
- }
-}
-
-/** Replaces CurrentDatabase expressions with the current database name. */
-case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = {
- plan transformAllExpressions {
- case CurrentDatabase() =>
- Literal.create(sessionCatalog.getCurrentDatabase, StringType)
- }
- }
-}
-
-/**
* This rule rewrites predicate sub-queries into left semi/anti joins. The following predicates
* are supported:
* a. EXISTS/NOT EXISTS will be rewritten as semi/anti join, unresolved conditions in Filter
http://git-wip-us.apache.org/repos/asf/spark/blob/dcefac43/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
new file mode 100644
index 0000000..0f43e7b
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -0,0 +1,269 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types.IntegerType
+
+/**
+ * This rule rewrites an aggregate query with distinct aggregations into an expanded double
+ * aggregation in which the regular aggregation expressions and every distinct clause are each
+ * aggregated in a separate group. The results are then combined in a second aggregate.
+ *
+ * For example (in scala):
+ * {{{
+ * val data = Seq(
+ * ("a", "ca1", "cb1", 10),
+ * ("a", "ca1", "cb2", 5),
+ * ("b", "ca1", "cb1", 13))
+ * .toDF("key", "cat1", "cat2", "value")
+ * data.createOrReplaceTempView("data")
+ *
+ * val agg = data.groupBy($"key")
+ * .agg(
+ * countDistinct($"cat1").as("cat1_cnt"),
+ * countDistinct($"cat2").as("cat2_cnt"),
+ * sum($"value").as("total"))
+ * }}}
+ *
+ * This translates to the following (pseudo) logical plan:
+ * {{{
+ * Aggregate(
+ * key = ['key]
+ * functions = [COUNT(DISTINCT 'cat1),
+ * COUNT(DISTINCT 'cat2),
+ * sum('value)]
+ * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
+ * LocalTableScan [...]
+ * }}}
+ *
+ * This rule rewrites this logical plan to the following (pseudo) logical plan:
+ * {{{
+ * Aggregate(
+ * key = ['key]
+ * functions = [count(if (('gid = 1)) 'cat1 else null),
+ * count(if (('gid = 2)) 'cat2 else null),
+ * first(if (('gid = 0)) 'total else null) ignore nulls]
+ * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
+ * Aggregate(
+ * key = ['key, 'cat1, 'cat2, 'gid]
+ * functions = [sum('value)]
+ * output = ['key, 'cat1, 'cat2, 'gid, 'total])
+ * Expand(
+ * projections = [('key, null, null, 0, cast('value as bigint)),
+ * ('key, 'cat1, null, 1, null),
+ * ('key, null, 'cat2, 2, null)]
+ * output = ['key, 'cat1, 'cat2, 'gid, 'value])
+ * LocalTableScan [...]
+ * }}}
+ *
+ * The rule does the following things here:
+ * 1. Expand the data. There are three aggregation groups in this query:
+ * i. the non-distinct group;
+ * ii. the distinct 'cat1 group;
+ * iii. the distinct 'cat2 group.
+ * An expand operator is inserted to expand the child data for each group. The expand will null
+ * out all unused columns for the given group; this must be done in order to ensure correctness
+ * later on. Groups can be identified by a group id (gid) column added by the expand operator.
+ * 2. De-duplicate the distinct paths and aggregate the non-distinct path. The group by clause of
+ * this aggregate consists of the original group by clause, all the requested distinct columns
+ * and the group id. Both the de-duplication of the distinct columns and the aggregation of the
+ * non-distinct group take advantage of the fact that we group by the group id (gid) and that we
+ * have nulled out all non-relevant columns for the given group.
+ * 3. Aggregate the distinct groups and combine this with the results of the non-distinct
+ * aggregation. In this step we use the group id to filter the inputs for the aggregate
+ * functions. The results of the non-distinct group are 'aggregated' by using the first operator;
+ * it might be more elegant to use the native UDAF merge mechanism for this in the future.
+ *
+ * This rule duplicates the input data two or more times (# distinct groups + an optional
+ * non-distinct group). This puts quite a bit of memory pressure on the aggregate and exchange
+ * operators involved. Keeping the number of distinct groups as low as possible should be a
+ * priority; we could improve this in the current rule by applying more advanced expression
+ * canonicalization techniques.
+ */
+object RewriteDistinctAggregates extends Rule[LogicalPlan] {
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ case a: Aggregate => rewrite(a)
+ }
+
+ def rewrite(a: Aggregate): Aggregate = {
+
+ // Collect all aggregate expressions.
+ val aggExpressions = a.aggregateExpressions.flatMap { e =>
+ e.collect {
+ case ae: AggregateExpression => ae
+ }
+ }
+
+ // Extract distinct aggregate expressions.
+ val distinctAggGroups = aggExpressions
+ .filter(_.isDistinct)
+ .groupBy(_.aggregateFunction.children.toSet)
+
+ // The aggregation strategy can handle a query with a single distinct group.
+ if (distinctAggGroups.size > 1) {
+ // Create the attributes for the grouping id and the group by clause.
+ val gid =
+ new AttributeReference("gid", IntegerType, false)(isGenerated = true)
+ val groupByMap = a.groupingExpressions.collect {
+ case ne: NamedExpression => ne -> ne.toAttribute
+ case e => e -> new AttributeReference(e.sql, e.dataType, e.nullable)()
+ }
+ val groupByAttrs = groupByMap.map(_._2)
+
+ // Functions used to modify aggregate functions and their inputs.
+ def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
+ def patchAggregateFunctionChildren(
+ af: AggregateFunction)(
+ attrs: Expression => Expression): AggregateFunction = {
+ af.withNewChildren(af.children.map {
+ case afc => attrs(afc)
+ }).asInstanceOf[AggregateFunction]
+ }
+
+ // Setup unique distinct aggregate children.
+ val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
+ val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
+ val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
+
+ // Setup expand & aggregate operators for distinct aggregate expressions.
+ val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap
+ val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
+ case ((group, expressions), i) =>
+ val id = Literal(i + 1)
+
+ // Expand projection
+ val projection = distinctAggChildren.map {
+ case e if group.contains(e) => e
+ case e => nullify(e)
+ } :+ id
+
+ // Final aggregate
+ val operators = expressions.map { e =>
+ val af = e.aggregateFunction
+ val naf = patchAggregateFunctionChildren(af) { x =>
+ evalWithinGroup(id, distinctAggChildAttrLookup(x))
+ }
+ (e, e.copy(aggregateFunction = naf, isDistinct = false))
+ }
+
+ (projection, operators)
+ }
+
+ // Setup expand for the 'regular' aggregate expressions.
+ val regularAggExprs = aggExpressions.filter(!_.isDistinct)
+ val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
+ val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)
+
+ // Setup aggregates for 'regular' aggregate expressions.
+ val regularGroupId = Literal(0)
+ val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
+ val regularAggOperatorMap = regularAggExprs.map { e =>
+ // Perform the actual aggregation in the initial aggregate.
+ val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
+ val operator = Alias(e.copy(aggregateFunction = af), e.sql)()
+
+ // Select the result of the first aggregate in the last aggregate.
+ val result = AggregateExpression(
+ aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)),
+ mode = Complete,
+ isDistinct = false)
+
+ // Some aggregate functions (COUNT) have the special property that they can return a
+ // non-null result without any input. We need to make sure we return a result in this case.
+ val resultWithDefault = af.defaultResult match {
+ case Some(lit) => Coalesce(Seq(result, lit))
+ case None => result
+ }
+
+ // Return a Tuple3 containing:
+ // i. The original aggregate expression (used for look ups).
+ // ii. The actual aggregation operator (used in the first aggregate).
+ // iii. The operator that selects and returns the result (used in the second aggregate).
+ (e, operator, resultWithDefault)
+ }
+
+ // Construct the regular aggregate input projection only if we need one.
+ val regularAggProjection = if (regularAggExprs.nonEmpty) {
+ Seq(a.groupingExpressions ++
+ distinctAggChildren.map(nullify) ++
+ Seq(regularGroupId) ++
+ regularAggChildren)
+ } else {
+ Seq.empty[Seq[Expression]]
+ }
+
+ // Construct the distinct aggregate input projections.
+ val regularAggNulls = regularAggChildren.map(nullify)
+ val distinctAggProjections = distinctAggOperatorMap.map {
+ case (projection, _) =>
+ a.groupingExpressions ++
+ projection ++
+ regularAggNulls
+ }
+
+ // Construct the expand operator.
+ val expand = Expand(
+ regularAggProjection ++ distinctAggProjections,
+ groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2),
+ a.child)
+
+ // Construct the first aggregate operator. This de-duplicates all the children of the
+ // distinct operators, and applies the regular aggregate operators.
+ val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid
+ val firstAggregate = Aggregate(
+ firstAggregateGroupBy,
+ firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2),
+ expand)
+
+ // Construct the second aggregate
+ val transformations: Map[Expression, Expression] =
+ (distinctAggOperatorMap.flatMap(_._2) ++
+ regularAggOperatorMap.map(e => (e._1, e._3))).toMap
+
+ val patchedAggExpressions = a.aggregateExpressions.map { e =>
+ e.transformDown {
+ case e: Expression =>
+ // The same GROUP BY clauses can have different forms (different names for instance) in
+ // the groupBy and aggregate expressions of an aggregate. This makes a map lookup
+ // tricky. So we do a linear search for a semantically equal group by expression.
+ groupByMap
+ .find(ge => e.semanticEquals(ge._1))
+ .map(_._2)
+ .getOrElse(transformations.getOrElse(e, e))
+ }.asInstanceOf[NamedExpression]
+ }
+ Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate)
+ } else {
+ a
+ }
+ }
+
+ private def nullify(e: Expression) = Literal.create(null, e.dataType)
+
+ private def expressionAttributePair(e: Expression) =
+ // We are creating a new reference here instead of reusing the attribute in the case of a
+ // NamedExpression. This is done to prevent collisions between distinct and regular aggregate
+ // children; in that case attribute reuse causes the input of the regular aggregate to be
+ // bound to the (nulled out) input of the distinct aggregate.
+ e -> new AttributeReference(e.sql, e.dataType, true)()
+}
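
For readers following along, here is a minimal sketch that reproduces the scaladoc example above and triggers the rewrite; the session setup (`spark`, local master, app name) is illustrative and not part of this commit:

  // Minimal sketch: trigger RewriteDistinctAggregates and inspect the optimized plan.
  // Assumes Spark 2.x; the session setup here is illustrative, not part of the commit.
  import org.apache.spark.sql.SparkSession
  import org.apache.spark.sql.functions._

  val spark = SparkSession.builder().master("local[2]").appName("rda-demo").getOrCreate()
  import spark.implicits._

  val data = Seq(
    ("a", "ca1", "cb1", 10),
    ("a", "ca1", "cb2", 5),
    ("b", "ca1", "cb1", 13))
    .toDF("key", "cat1", "cat2", "value")

  // Two distinct aggregates over different column sets mean more than one distinct
  // group, so the rule rewrites the plan into an Expand plus two Aggregates.
  val agg = data.groupBy($"key").agg(
    countDistinct($"cat1").as("cat1_cnt"),
    countDistinct($"cat2").as("cat2_cnt"),
    sum($"value").as("total"))

  // The optimized logical plan should show an Expand node and a 'gid' column.
  agg.explain(true)

With a single distinct group the rule leaves the plan untouched, since the aggregation strategy handles that case directly.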
http://git-wip-us.apache.org/repos/asf/spark/blob/dcefac43/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
new file mode 100644
index 0000000..7c66731
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.catalog.SessionCatalog
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.types._
+
+
+/**
+ * Finds all [[RuntimeReplaceable]] expressions and replaces them with expressions that can
+ * be evaluated. This is mainly used to provide compatibility with other databases.
+ * For example, we use this to support "nvl" by replacing it with "coalesce".
+ */
+object ReplaceExpressions extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ case e: RuntimeReplaceable => e.replaced
+ }
+}
+
+
+/**
+ * Computes the current date and time once so that all references to them within a single query return the same result.
+ */
+object ComputeCurrentTime extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ val dateExpr = CurrentDate()
+ val timeExpr = CurrentTimestamp()
+ val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType)
+ val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType)
+
+ plan transformAllExpressions {
+ case CurrentDate() => currentDate
+ case CurrentTimestamp() => currentTime
+ }
+ }
+}
+
+
+/** Replaces CurrentDatabase expressions with the current database name. */
+case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ plan transformAllExpressions {
+ case CurrentDatabase() =>
+ Literal.create(sessionCatalog.getCurrentDatabase, StringType)
+ }
+ }
+}
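
The behavior of the three relocated rules can be observed directly from SQL; a short sketch, again assuming the illustrative `spark` session from the example above:

  // ReplaceExpressions: nvl is RuntimeReplaceable and is rewritten to coalesce,
  // so it evaluates without a dedicated physical implementation.
  spark.sql("SELECT nvl(NULL, 'fallback') AS value").show()

  // ComputeCurrentTime: every current_timestamp() in a query is folded into the
  // same literal before execution, so this comparison always yields true.
  spark.sql("SELECT current_timestamp() = current_timestamp() AS same_instant").show()

  // GetCurrentDatabase: CurrentDatabase() is replaced with a string literal holding
  // the session catalog's current database.
  spark.sql("SELECT current_database() AS db").show()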