Posted to commits@spark.apache.org by ma...@apache.org on 2015/09/02 20:32:32 UTC

spark git commit: [SPARK-10389] [SQL] support order by non-attribute grouping expression on Aggregate

Repository: spark
Updated Branches:
  refs/heads/master 56c4c172e -> fc4830779


[SPARK-10389] [SQL] support order by non-attribute grouping expression on Aggregate

For example, we can write `SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY key + 1` in PostgreSQL, and we should support this in Spark SQL.
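
For reference, a minimal way to exercise the new behavior from Scala, mirroring the test added in this commit (the `src` table and the sample rows are just the test fixture; assumes a Spark 1.5-era SQLContext with its implicits in scope):

    import sqlContext.implicits._

    // Register a small two-column temp table, as the new test does.
    Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src")

    // Order the aggregate by the non-attribute grouping expression key + 1.
    // Before this patch the ORDER BY failed to resolve, because `key` is
    // projected away by the Aggregate; with it, the query returns two rows,
    // each with MAX(value) = 1.
    sqlContext.sql(
      "SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY key + 1").show()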

Author: Wenchen Fan <cl...@outlook.com>

Closes #8548 from cloud-fan/support-order-by-non-attribute.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fc483077
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fc483077
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fc483077

Branch: refs/heads/master
Commit: fc48307797912dc1d53893dce741ddda8630957b
Parents: 56c4c17
Author: Wenchen Fan <cl...@outlook.com>
Authored: Wed Sep 2 11:32:27 2015 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Wed Sep 2 11:32:27 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 72 ++++++++++----------
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 19 ++++--
 2 files changed, 52 insertions(+), 39 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fc483077/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 1a5de15..591747b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -560,43 +560,47 @@ class Analyzer(
           filter
         }
 
-      case sort @ Sort(sortOrder, global,
-             aggregate @ Aggregate(grouping, originalAggExprs, child))
+      case sort @ Sort(sortOrder, global, aggregate: Aggregate)
         if aggregate.resolved && !sort.resolved =>
 
         // Try resolving the ordering as though it is in the aggregate clause.
         try {
-          val aliasedOrder = sortOrder.map(o => Alias(o.child, "aggOrder")())
-          val aggregatedOrdering = Aggregate(grouping, aliasedOrder, child)
-          val resolvedOperator: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
-          def resolvedAggregateOrdering = resolvedOperator.aggregateExpressions
-
-          // Expressions that have an aggregate can be pushed down.
-          val needsAggregate = resolvedAggregateOrdering.exists(containsAggregate)
-
-          // Attribute references, that are missing from the order but are present in the grouping
-          // expressions can also be pushed down.
-          val requiredAttributes = resolvedAggregateOrdering.map(_.references).reduce(_ ++ _)
-          val missingAttributes = requiredAttributes -- aggregate.outputSet
-          val validPushdownAttributes =
-            missingAttributes.filter(a => grouping.exists(a.semanticEquals))
-
-          // If resolution was successful and we see the ordering either has an aggregate in it or
-          // it is missing something that is projected away by the aggregate, add the ordering
-          // the original aggregate operator.
-          if (resolvedOperator.resolved && (needsAggregate || validPushdownAttributes.nonEmpty)) {
-            val evaluatedOrderings: Seq[SortOrder] = sortOrder.zip(resolvedAggregateOrdering).map {
-              case (order, evaluated) => order.copy(child = evaluated.toAttribute)
-            }
-            val aggExprsWithOrdering: Seq[NamedExpression] =
-              resolvedAggregateOrdering ++ originalAggExprs
-
-            Project(aggregate.output,
-              Sort(evaluatedOrderings, global,
-                aggregate.copy(aggregateExpressions = aggExprsWithOrdering)))
-          } else {
-            sort
+          val aliasedOrdering = sortOrder.map(o => Alias(o.child, "aggOrder")())
+          val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
+          val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
+          val resolvedAliasedOrdering: Seq[Alias] =
+            resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]]
+
+          // If we pass the analysis check, then the ordering expressions can only refer to
+          // aggregate expressions or grouping expressions, and it's safe to push them down into
+          // the Aggregate.
+          checkAnalysis(resolvedAggregate)
+
+          val originalAggExprs = aggregate.aggregateExpressions.map(
+            CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
+
+          // If the ordering expression is the same as an original aggregate expression, we don't
+          // need to push it down and can simply reference the original aggregate expression
+          // instead.
+          val needsPushDown = ArrayBuffer.empty[NamedExpression]
+          val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map {
+            case (evaluated, order) =>
+              val index = originalAggExprs.indexWhere {
+                case Alias(child, _) => child semanticEquals evaluated.child
+                case other => other semanticEquals evaluated.child
+              }
+
+              if (index == -1) {
+                needsPushDown += evaluated
+                order.copy(child = evaluated.toAttribute)
+              } else {
+                order.copy(child = originalAggExprs(index).toAttribute)
+              }
           }
+
+          Project(aggregate.output,
+            Sort(evaluatedOrderings, global,
+              aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown)))
         } catch {
           // Attempting to resolve in the aggregate can result in ambiguity.  When this happens,
           // just return the original plan.
@@ -605,9 +609,7 @@ class Analyzer(
     }
 
     protected def containsAggregate(condition: Expression): Boolean = {
-      condition
-        .collect { case ae: AggregateExpression => ae }
-        .nonEmpty
+      condition.find(_.isInstanceOf[AggregateExpression]).isDefined
     }
   }
 

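The heart of the new rule is the matching step above: each resolved ordering expression either already appears among the Aggregate's output expressions, in which case the sort can simply reference that output column, or it must be pushed down as an extra aggregate expression and projected away afterwards. Below is a small self-contained sketch of that step using a toy expression type; `Expr`, `semanticEquals`, and the sample expressions are illustrative stand-ins for Catalyst's classes, not Spark API:

    import scala.collection.mutable.ArrayBuffer

    // Toy stand-ins for Catalyst expressions, just to illustrate the rule's
    // matching step (not the real Catalyst API).
    sealed trait Expr { def semanticEquals(other: Expr): Boolean = this == other }
    case class Attr(name: String) extends Expr
    case class Lit(v: Int) extends Expr
    case class Add(l: Expr, r: Expr) extends Expr
    case class Max(e: Expr) extends Expr
    case class Alias(child: Expr, name: String) extends Expr

    object PushDownSketch extends App {
      // Output of the Aggregate: MAX(value) AS `max(value)`, grouped by key + 1.
      val originalAggExprs: Seq[Expr] = Seq(Alias(Max(Attr("value")), "max(value)"))
      // The resolved ordering, each expression wrapped in an `aggOrder` alias.
      val resolvedOrdering: Seq[Alias] = Seq(Alias(Add(Attr("key"), Lit(1)), "aggOrder"))

      val needsPushDown = ArrayBuffer.empty[Expr]
      val orderingRefs = resolvedOrdering.map { evaluated =>
        // Is there an aggregate output semantically equal to this ordering expression?
        val index = originalAggExprs.indexWhere {
          case Alias(child, _) => child.semanticEquals(evaluated.child)
          case other           => other.semanticEquals(evaluated.child)
        }
        if (index == -1) {
          // No: push it down as an extra output column and sort on that.
          needsPushDown += evaluated
          evaluated.name
        } else {
          // Yes: sort on the existing output column instead.
          s"output #$index"
        }
      }

      // Final plan shape, as built by the rule:
      //   Project(aggregate.output,
      //     Sort(orderingRefs, global,
      //       aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown)))
      println(s"sort on $orderingRefs, pushing down ${needsPushDown.size} expression(s)")
    }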
http://git-wip-us.apache.org/repos/asf/spark/blob/fc483077/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 2820107..0ef25fe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1722,9 +1722,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }
 
   test("SPARK-10130 type coercion for IF should have children resolved first") {
-    val df = Seq((1, 1), (-1, 1)).toDF("key", "value")
-    df.registerTempTable("src")
-    checkAnswer(
-      sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0)))
+    withTempTable("src") {
+      Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src")
+      checkAnswer(
+        sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0)))
+    }
+  }
+
+  test("SPARK-10389: order by non-attribute grouping expression on Aggregate") {
+    withTempTable("src") {
+      Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src")
+      checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY key + 1"),
+        Seq(Row(1), Row(1)))
+      checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY (key + 1) * 2"),
+        Seq(Row(1), Row(1)))
+    }
   }
 }

