You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/10/18 12:59:46 UTC

spark git commit: [SPARK-22266][SQL] The same aggregate function was evaluated multiple times

Repository: spark
Updated Branches:
  refs/heads/master f3137feec -> 72561ecf4


[SPARK-22266][SQL] The same aggregate function was evaluated multiple times

## What changes were proposed in this pull request?

To let the same aggregate function that appears multiple times in an Aggregate be evaluated only once, we need to deduplicate the aggregate expressions. The original code tried to use a "distinct" call to get a set of aggregate expressions, but this did not work, since "distinct" does not compare semantic equality. And even if it did, further work would be needed when rewriting the result expressions.
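In this PR, I changed the "set" to a map (via `EquivalentExpressions`) from the semantic identity of an aggregate expression to the expression itself. Thus, later on, when rewriting result expressions (i.e., output expressions), each aggregate expression reference can be redirected to the single copy that is actually computed. Non-deterministic aggregate expressions are not deduplicated.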

## How was this patch tested?

Added two new tests in SQLQuerySuite.

Author: maryannxue <ma...@gmail.com>

Closes #19488 from maryannxue/spark-22266.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/72561ecf
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/72561ecf
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/72561ecf

Branch: refs/heads/master
Commit: 72561ecf4b611d68f8bf695ddd0c4c2cce3a29d9
Parents: f3137fe
Author: maryannxue <ma...@gmail.com>
Authored: Wed Oct 18 20:59:40 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Wed Oct 18 20:59:40 2017 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/planning/patterns.scala  | 16 +++++++-----
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 26 ++++++++++++++++++++
 2 files changed, 36 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/72561ecf/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 8d034c2..cc391aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -205,14 +205,17 @@ object PhysicalAggregation {
     case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
       // A single aggregate expression might appear multiple times in resultExpressions.
       // In order to avoid evaluating an individual aggregate function multiple times, we'll
-      // build a set of the distinct aggregate expressions and build a function which can
-      // be used to re-write expressions so that they reference the single copy of the
-      // aggregate function which actually gets computed.
+      // build a set of semantically distinct aggregate expressions and re-write expressions so
+      // that they reference the single copy of the aggregate function which actually gets computed.
+      // Non-deterministic aggregate expressions are not deduplicated.
+      val equivalentAggregateExpressions = new EquivalentExpressions
       val aggregateExpressions = resultExpressions.flatMap { expr =>
         expr.collect {
-          case agg: AggregateExpression => agg
+          // addExpr() always returns false for non-deterministic expressions and does not add them.
+          case agg: AggregateExpression
+            if (!equivalentAggregateExpressions.addExpr(agg)) => agg
         }
-      }.distinct
+      }
 
       val namedGroupingExpressions = groupingExpressions.map {
         case ne: NamedExpression => ne -> ne
@@ -236,7 +239,8 @@ object PhysicalAggregation {
           case ae: AggregateExpression =>
             // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
             // so replace each aggregate expression by its corresponding attribute in the set:
-            ae.resultAttribute
+            equivalentAggregateExpressions.getEquivalentExprs(ae).headOption
+              .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute
           case expression =>
             // Since we're using `namedGroupingAttributes` to extract the grouping key
             // columns, we need to replace grouping key expressions with their corresponding

http://git-wip-us.apache.org/repos/asf/spark/blob/72561ecf/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index f0c58e2..caf332d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.{AccumulatorSuite, SparkException}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.execution.aggregate
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -2715,4 +2716,29 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       checkAnswer(df, Row(1, 1, 1))
     }
   }
+
+  test("SPARK-22266: the same aggregate function was calculated multiple times") {
+    val query = "SELECT a, max(b+1), max(b+1) + 1 FROM testData2 GROUP BY a"
+    val df = sql(query)
+    val physical = df.queryExecution.sparkPlan
+    val aggregateExpressions = physical.collectFirst {
+      case agg: HashAggregateExec => agg.aggregateExpressions
+      case agg: SortAggregateExec => agg.aggregateExpressions
+    }
+    assert(aggregateExpressions.isDefined)
+    assert(aggregateExpressions.get.size == 1) // both max(b+1) uses share one evaluation
+    checkAnswer(df, Row(1, 3, 4) :: Row(2, 3, 4) :: Row(3, 3, 4) :: Nil)
+  }
+
+  test("SPARK-22266: non-deterministic aggregate functions should not be deduplicated") {
+    val query = "SELECT a, first_value(b), first_value(b) + 1 FROM testData2 GROUP BY a"
+    val df = sql(query)
+    val physical = df.queryExecution.sparkPlan
+    val aggregateExpressions = physical.collectFirst {
+      case agg: HashAggregateExec => agg.aggregateExpressions
+      case agg: SortAggregateExec => agg.aggregateExpressions
+    }
+    assert(aggregateExpressions.isDefined)
+    assert(aggregateExpressions.get.size == 2) // first_value is non-deterministic: kept twice
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org