You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2015/06/15 21:01:57 UTC

spark git commit: [SPARK-6583] [SQL] Support aggregate functions in ORDER BY

Repository: spark
Updated Branches:
  refs/heads/master 56d4e8a2d -> 6ae21a944


[SPARK-6583] [SQL] Support aggregate functions in ORDER BY

Add aggregate expressions that appear in ORDER BY clauses to the `Aggregate` operator beneath the `Sort`, so they are evaluated there.  Project these extra results away after the `Sort`.

Based on work by watermen.  Also closes #5290.

Author: Yadong Qi <qi...@gmail.com>
Author: Michael Armbrust <mi...@databricks.com>

Closes #6816 from marmbrus/pr/5290 and squashes the following commits:

3226a97 [Michael Armbrust] consistent ordering
eb8938d [Michael Armbrust] no vars
c8b25c1 [Yadong Qi] move the test data.
7f9b736 [Yadong Qi] delete Substring case
a1e87c1 [Yadong Qi] fix conflict
f119849 [Yadong Qi] order by aggregated function


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6ae21a94
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6ae21a94
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6ae21a94

Branch: refs/heads/master
Commit: 6ae21a944a0f4580b55749776223c827450b00da
Parents: 56d4e8a
Author: Yadong Qi <qi...@gmail.com>
Authored: Mon Jun 15 12:01:52 2015 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Mon Jun 15 12:01:52 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 19 +++++++--
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 45 ++++++++++++++++++++
 2 files changed, 61 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6ae21a94/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 4b7fef7..badf903 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.types._
+import scala.collection.mutable.ArrayBuffer
 
 /**
  * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing
@@ -396,19 +397,31 @@ class Analyzer(
         }
       case s @ Sort(ordering, global, a @ Aggregate(grouping, aggs, child))
           if !s.resolved && a.resolved =>
-        val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
         // A small hack to create an object that will allow us to resolve any references that
         // refer to named expressions that are present in the grouping expressions.
         val groupingRelation = LocalRelation(
           grouping.collect { case ne: NamedExpression => ne.toAttribute }
         )
 
-        val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, a, groupingRelation)
+        // Find sort attributes that are projected away so we can temporarily add them back in.
+        val (resolvedOrdering, unresolved) = resolveAndFindMissing(ordering, a, groupingRelation)
+
+        // Find aggregate expressions and evaluate them early, since they can't be evaluated in a
+        // Sort.
+        val (withAggsRemoved, aliasedAggregateList) = resolvedOrdering.map {
+          case aggOrdering if aggOrdering.collect { case a: AggregateExpression => a }.nonEmpty =>
+            val aliased = Alias(aggOrdering.child, "_aggOrdering")()
+            (aggOrdering.copy(child = aliased.toAttribute), aliased :: Nil)
+
+          case other => (other, Nil)
+        }.unzip
+
+        val missing = unresolved ++ aliasedAggregateList.flatten
 
         if (missing.nonEmpty) {
           // Add missing grouping exprs and then project them away after the sort.
           Project(a.output,
-            Sort(resolvedOrdering, global,
+            Sort(withAggsRemoved, global,
               Aggregate(grouping, aggs ++ missing, child)))
         } else {
           s // Nothing we can do here. Return original plan.

http://git-wip-us.apache.org/repos/asf/spark/blob/6ae21a94/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index d1520b7..a47cc30 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1366,6 +1366,51 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
     checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1))
   }
 
+  test("SPARK-6583 order by aggregated function") {
+    Seq("1" -> 3, "1" -> 4, "2" -> 7, "2" -> 8, "3" -> 5, "3" -> 6, "4" -> 1, "4" -> 2)
+      .toDF("a", "b").registerTempTable("orderByData")
+
+    checkAnswer(
+      sql(
+        """
+          |SELECT a
+          |FROM orderByData
+          |GROUP BY a
+          |ORDER BY sum(b)
+        """.stripMargin),
+      Row("4") :: Row("1") :: Row("3") :: Row("2") :: Nil)
+
+    checkAnswer(
+      sql(
+        """
+          |SELECT sum(b)
+          |FROM orderByData
+          |GROUP BY a
+          |ORDER BY sum(b)
+        """.stripMargin),
+      Row(3) :: Row(7) :: Row(11) :: Row(15) :: Nil)
+
+    checkAnswer(
+      sql(
+        """
+          |SELECT a, sum(b)
+          |FROM orderByData
+          |GROUP BY a
+          |ORDER BY sum(b)
+        """.stripMargin),
+      Row("4", 3) :: Row("1", 7) :: Row("3", 11) :: Row("2", 15) :: Nil)
+
+    checkAnswer(
+      sql(
+        """
+            |SELECT a, sum(b)
+            |FROM orderByData
+            |GROUP BY a
+            |ORDER BY sum(b) + 1
+          """.stripMargin),
+      Row("4", 3) :: Row("1", 7) :: Row("3", 11) :: Row("2", 15) :: Nil)
+  }
+
   test("SPARK-7952: fix the equality check between boolean and numeric types") {
     withTempTable("t") {
       // numeric field i, boolean field j, result of i = j, result of i <=> j


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org