You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2016/02/01 20:57:21 UTC

spark git commit: [SPARK-12705][SPARK-10777][SQL] Analyzer Rule ResolveSortReferences

Repository: spark
Updated Branches:
  refs/heads/master 33c8a490f -> 8f26eb5ef


[SPARK-12705][SPARK-10777][SQL] Analyzer Rule ResolveSortReferences

JIRA: https://issues.apache.org/jira/browse/SPARK-12705

**Scope:**
This PR is a general fix for sort reference resolution when the child's `outputSet` does not contain the order-by attributes (called *missing attributes*):
  - UnaryNode support is limited to `Project`, `Window`, `Aggregate`, `Distinct`, `Filter`, `RepartitionByExpression`.
  - We will not try to resolve the missing references inside a subquery, unless the `outputSet` of this subquery contains them.

**General Reference Resolution Rules:**
  - Jump over nodes of the following types: `Distinct`, `Filter`, `RepartitionByExpression`. The missing attributes do not need to be added to these nodes, because their `outputSet` is determined by their `inputSet`, which is the `outputSet` of their children.
  - Group-by expressions in `Aggregate`: missing order-by attributes must not be added to the group-by expressions, since that would change the query result. RDBMSs disallow this for the same reason.
  - Aggregate expressions in `Aggregate`: if the group-by expressions in `Aggregate` contain the missing attributes but the aggregate expressions do not, just add them to the aggregate expressions. This resolves the `AnalysisException`s thrown by the three TPC-DS queries.
  - `Project` and `Window` are special. We just need to add the missing attributes to their `projectList`.

**Implementation:**
  1. Traverse the whole tree in a pre-order manner to find all the resolvable missing order-by attributes.
  2. Traverse the whole tree in a post-order manner to add the found missing order-by attributes to each node whose `inputSet` contains them.
  3. If the origins of the missing order-by attributes are different nodes, each pass only resolves the missing attributes that are from the same node.

**Risk:**
Low. This rule will be triggered iff `!s.resolved && child.resolved` is true. Thus, very few cases are affected.

Author: gatorsmile <ga...@gmail.com>

Closes #10678 from gatorsmile/sortWindows.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8f26eb5e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8f26eb5e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8f26eb5e

Branch: refs/heads/master
Commit: 8f26eb5ef6853a6666d7d9481b333de70bc501ed
Parents: 33c8a49
Author: gatorsmile <ga...@gmail.com>
Authored: Mon Feb 1 11:57:13 2016 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Mon Feb 1 11:57:13 2016 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 101 +++++++++++++++----
 .../sql/catalyst/analysis/AnalysisSuite.scala   |  83 +++++++++++++++
 .../sql/catalyst/analysis/TestRelations.scala   |   6 ++
 .../apache/spark/sql/DataFrameJoinSuite.scala   |  16 +++
 .../org/apache/spark/sql/DataFrameSuite.scala   |   6 ++
 .../sql/hive/execution/SQLQuerySuite.scala      |  84 ++++++++++++++-
 6 files changed, 274 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8f26eb5e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index ee60fca..a983dc1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import scala.annotation.tailrec
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.AnalysisException
@@ -452,7 +453,7 @@ class Analyzer(
         i.copy(right = dedupRight(left, right))
 
       // When resolve `SortOrder`s in Sort based on child, don't report errors as
-      // we still have chance to resolve it based on grandchild
+      // we still have chance to resolve it based on its descendants
       case s @ Sort(ordering, global, child) if child.resolved && !s.resolved =>
         val newOrdering = resolveSortOrders(ordering, child, throws = false)
         Sort(newOrdering, global, child)
@@ -533,38 +534,96 @@ class Analyzer(
    */
   object ResolveSortReferences extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-      case s @ Sort(ordering, global, p @ Project(projectList, child))
-          if !s.resolved && p.resolved =>
-        val (newOrdering, missing) = resolveAndFindMissing(ordering, p, child)
+      // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
+      case sa @ Sort(_, _, child: Aggregate) => sa
 
-        // If this rule was not a no-op, return the transformed plan, otherwise return the original.
-        if (missing.nonEmpty) {
-          // Add missing attributes and then project them away after the sort.
-          Project(p.output,
-            Sort(newOrdering, global,
-              Project(projectList ++ missing, child)))
-        } else {
-          logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}")
+      case s @ Sort(_, _, child) if !s.resolved && child.resolved =>
+        val (newOrdering, missingResolvableAttrs) = collectResolvableMissingAttrs(s.order, child)
+
+        if (missingResolvableAttrs.isEmpty) {
+          val unresolvableAttrs = s.order.filterNot(_.resolved)
+          logDebug(s"Failed to find $unresolvableAttrs in ${child.output.mkString(", ")}")
           s // Nothing we can do here. Return original plan.
+        } else {
+          // Add the missing attributes into projectList of Project/Window or
+          //   aggregateExpressions of Aggregate, if they are in the inputSet
+          //   but not in the outputSet of the plan.
+          val newChild = child transformUp {
+            case p: Project =>
+              p.copy(projectList = p.projectList ++
+                missingResolvableAttrs.filter((p.inputSet -- p.outputSet).contains))
+            case w: Window =>
+              w.copy(projectList = w.projectList ++
+                missingResolvableAttrs.filter((w.inputSet -- w.outputSet).contains))
+            case a: Aggregate =>
+              val resolvableAttrs = missingResolvableAttrs.filter(a.groupingExpressions.contains)
+              val notResolvedAttrs = resolvableAttrs.filterNot(a.aggregateExpressions.contains)
+              val newAggregateExpressions = a.aggregateExpressions ++ notResolvedAttrs
+              a.copy(aggregateExpressions = newAggregateExpressions)
+            case o => o
+          }
+
+          // Add missing attributes and then project them away after the sort.
+          Project(child.output,
+            Sort(newOrdering, s.global, newChild))
         }
     }
 
     /**
-     * Given a child and a grandchild that are present beneath a sort operator, try to resolve
-     * the sort ordering and returns it with a list of attributes that are missing from the
-     * child but are present in the grandchild.
+     * Traverse the tree until resolving the sorting attributes
+     * Return all the resolvable missing sorting attributes
+     */
+    @tailrec
+    private def collectResolvableMissingAttrs(
+        ordering: Seq[SortOrder],
+        plan: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
+      plan match {
+        // Only Windows and Project have projectList-like attribute.
+        case un: UnaryNode if un.isInstanceOf[Project] || un.isInstanceOf[Window] =>
+          val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, un, un.child)
+          // If missingAttrs is non empty, that means we got it and return it;
+          // Otherwise, continue to traverse the tree.
+          if (missingAttrs.nonEmpty) {
+            (newOrdering, missingAttrs)
+          } else {
+            collectResolvableMissingAttrs(ordering, un.child)
+          }
+        case a: Aggregate =>
+          val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, a, a.child)
+          // For Aggregate, all the order by columns must be specified in group by clauses
+          if (missingAttrs.nonEmpty &&
+              missingAttrs.forall(ar => a.groupingExpressions.exists(_.semanticEquals(ar)))) {
+            (newOrdering, missingAttrs)
+          } else {
+            // If missingAttrs is empty, we are unable to resolve any unresolved missing attributes
+            (Seq.empty[SortOrder], Seq.empty[Attribute])
+          }
+        // Jump over the following UnaryNode types
+        // The output of these types is the same as their child's output
+        case _: Distinct |
+             _: Filter |
+             _: RepartitionByExpression =>
+          collectResolvableMissingAttrs(ordering, plan.asInstanceOf[UnaryNode].child)
+        // If hitting the other unsupported operators, we are unable to resolve it.
+        case other => (Seq.empty[SortOrder], Seq.empty[Attribute])
+      }
+    }
+
+    /**
+     * Try to resolve the sort ordering and returns it with a list of attributes that are missing
+     * from the plan but are present in the child.
      */
-    def resolveAndFindMissing(
+    private def resolveAndFindMissing(
         ordering: Seq[SortOrder],
-        child: LogicalPlan,
-        grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
-      val newOrdering = resolveSortOrders(ordering, grandchild, throws = true)
+        plan: LogicalPlan,
+        child: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
+      val newOrdering = resolveSortOrders(ordering, child, throws = false)
       // Construct a set that contains all of the attributes that we need to evaluate the
       // ordering.
       val requiredAttributes = AttributeSet(newOrdering).filter(_.resolved)
       // Figure out which ones are missing from the projection, so that we can add them and
       // remove them after the sort.
-      val missingInProject = requiredAttributes -- child.output
+      val missingInProject = requiredAttributes -- plan.outputSet
       // It is important to return the new SortOrders here, instead of waiting for the standard
       // resolving process as adding attributes to the project below can actually introduce
       // ambiguity that was not present before.
@@ -719,7 +778,7 @@ class Analyzer(
         }
     }
 
-    protected def containsAggregate(condition: Expression): Boolean = {
+    def containsAggregate(condition: Expression): Boolean = {
       condition.find(_.isInstanceOf[AggregateExpression]).isDefined
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/8f26eb5e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 1938bce..ebf885a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -76,6 +76,89 @@ class AnalysisSuite extends AnalysisTest {
       caseSensitive = false)
   }
 
+  test("resolve sort references - filter/limit") {
+    val a = testRelation2.output(0)
+    val b = testRelation2.output(1)
+    val c = testRelation2.output(2)
+
+    // Case 1: one missing attribute is in the leaf node and another is in the unary node
+    val plan1 = testRelation2
+      .where('a > "str").select('a, 'b)
+      .where('b > "str").select('a)
+      .sortBy('b.asc, 'c.desc)
+    val expected1 = testRelation2
+      .where(a > "str").select(a, b, c)
+      .where(b > "str").select(a, b, c)
+      .sortBy(b.asc, c.desc)
+      .select(a, b).select(a)
+    checkAnalysis(plan1, expected1)
+
+    // Case 2: all the missing attributes are in the leaf node
+    val plan2 = testRelation2
+      .where('a > "str").select('a)
+      .where('a > "str").select('a)
+      .sortBy('b.asc, 'c.desc)
+    val expected2 = testRelation2
+      .where(a > "str").select(a, b, c)
+      .where(a > "str").select(a, b, c)
+      .sortBy(b.asc, c.desc)
+      .select(a)
+    checkAnalysis(plan2, expected2)
+  }
+
+  test("resolve sort references - join") {
+    val a = testRelation2.output(0)
+    val b = testRelation2.output(1)
+    val c = testRelation2.output(2)
+    val h = testRelation3.output(3)
+
+    // Case: join itself can resolve all the missing attributes
+    val plan = testRelation2.join(testRelation3)
+      .where('a > "str").select('a, 'b)
+      .sortBy('c.desc, 'h.asc)
+    val expected = testRelation2.join(testRelation3)
+      .where(a > "str").select(a, b, c, h)
+      .sortBy(c.desc, h.asc)
+      .select(a, b)
+    checkAnalysis(plan, expected)
+  }
+
+  test("resolve sort references - aggregate") {
+    val a = testRelation2.output(0)
+    val b = testRelation2.output(1)
+    val c = testRelation2.output(2)
+    val alias_a3 = count(a).as("a3")
+    val alias_b = b.as("aggOrder")
+
+    // Case 1: when the child of Sort is not Aggregate,
+    //   the sort reference is handled by the rule ResolveSortReferences
+    val plan1 = testRelation2
+      .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3"))
+      .select('a, 'c, 'a3)
+      .orderBy('b.asc)
+
+    val expected1 = testRelation2
+      .groupBy(a, c, b)(a, c, alias_a3, b)
+      .select(a, c, alias_a3.toAttribute, b)
+      .orderBy(b.asc)
+      .select(a, c, alias_a3.toAttribute)
+
+    checkAnalysis(plan1, expected1)
+
+    // Case 2: when the child of Sort is Aggregate,
+    //   the sort reference is handled by the rule ResolveAggregateFunctions
+    val plan2 = testRelation2
+      .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3"))
+      .orderBy('b.asc)
+
+    val expected2 = testRelation2
+      .groupBy(a, c, b)(a, c, alias_a3, alias_b)
+      .orderBy(alias_b.toAttribute.asc)
+      .select(a, c, alias_a3.toAttribute)
+
+    checkAnalysis(plan2, expected2)
+  }
+
   test("resolve relations") {
     assertAnalysisError(
       UnresolvedRelation(TableIdentifier("tAbLe"), None), Seq("Table not found: tAbLe"))

http://git-wip-us.apache.org/repos/asf/spark/blob/8f26eb5e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
index bc07b60..3741a6b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
@@ -31,6 +31,12 @@ object TestRelations {
     AttributeReference("d", DecimalType(10, 2))(),
     AttributeReference("e", ShortType)())
 
+  val testRelation3 = LocalRelation(
+    AttributeReference("e", ShortType)(),
+    AttributeReference("f", StringType)(),
+    AttributeReference("g", DoubleType)(),
+    AttributeReference("h", DecimalType(10, 2))())
+
   val nestedRelation = LocalRelation(
     AttributeReference("top", StructType(
       StructField("duplicateField", StringType) ::

http://git-wip-us.apache.org/repos/asf/spark/blob/8f26eb5e/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index c17be8a..a5e5f15 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -42,6 +42,22 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
       Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil)
   }
 
+  test("join - sorted columns not in join's outputSet") {
+    val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str_sort").as('df1)
+    val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as('df2)
+    val df3 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as('df3)
+
+    checkAnswer(
+      df.join(df2, $"df1.int" === $"df2.int", "outer").select($"df1.int", $"df2.int2")
+        .orderBy('str_sort.asc, 'str.asc),
+      Row(null, 6) :: Row(1, 3) :: Row(3, null) :: Nil)
+
+    checkAnswer(
+      df2.join(df3, $"df2.int" === $"df3.int", "inner")
+        .select($"df2.int", $"df3.int").orderBy($"df2.str".desc),
+      Row(5, 5) :: Row(1, 1) :: Nil)
+  }
+
   test("join - join using multiple columns and specifying join type") {
     val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str")
     val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str")

http://git-wip-us.apache.org/repos/asf/spark/blob/8f26eb5e/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 4ff99bd..c02133f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -954,6 +954,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     assert(expected === actual)
   }
 
+  test("Sorting columns are not in Filter and Project") {
+    checkAnswer(
+      upperCaseData.filter('N > 1).select('N).filter('N < 6).orderBy('L.asc),
+      Row(2) :: Row(3) :: Row(4) :: Row(5) :: Nil)
+  }
+
   test("SPARK-9323: DataFrame.orderBy should support nested column name") {
     val df = sqlContext.read.json(sparkContext.makeRDD(
       """{"a": {"b": 1}}""" :: Nil))

http://git-wip-us.apache.org/repos/asf/spark/blob/8f26eb5e/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 1ada2e3..6048b8f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -736,7 +736,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
         """.stripMargin), (2 to 6).map(i => Row(i)))
   }
 
-  test("window function: udaf with aggregate expressin") {
+  test("window function: udaf with aggregate expression") {
     val data = Seq(
       WindowData(1, "a", 5),
       WindowData(2, "a", 6),
@@ -927,6 +927,88 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
       ).map(i => Row(i._1, i._2, i._3, i._4)))
   }
 
+  test("window function: Sorting columns are not in Project") {
+    val data = Seq(
+      WindowData(1, "d", 10),
+      WindowData(2, "a", 6),
+      WindowData(3, "b", 7),
+      WindowData(4, "b", 8),
+      WindowData(5, "c", 9),
+      WindowData(6, "c", 11)
+    )
+    sparkContext.parallelize(data).toDF().registerTempTable("windowData")
+
+    checkAnswer(
+      sql("select month, product, sum(product + 1) over() from windowData order by area"),
+      Seq(
+        (2, 6, 57),
+        (3, 7, 57),
+        (4, 8, 57),
+        (5, 9, 57),
+        (6, 11, 57),
+        (1, 10, 57)
+      ).map(i => Row(i._1, i._2, i._3)))
+
+    checkAnswer(
+      sql(
+        """
+          |select area, rank() over (partition by area order by tmp.month) + tmp.tmp1 as c1
+          |from (select month, area, product as p, 1 as tmp1 from windowData) tmp order by p
+        """.stripMargin),
+      Seq(
+        ("a", 2),
+        ("b", 2),
+        ("b", 3),
+        ("c", 2),
+        ("d", 2),
+        ("c", 3)
+      ).map(i => Row(i._1, i._2)))
+
+    checkAnswer(
+      sql(
+      """
+        |select area, rank() over (partition by area order by month) as c1
+        |from windowData group by product, area, month order by product, area
+      """.stripMargin),
+      Seq(
+        ("a", 1),
+        ("b", 1),
+        ("b", 2),
+        ("c", 1),
+        ("d", 1),
+        ("c", 2)
+      ).map(i => Row(i._1, i._2)))
+  }
+
+  // todo: fix this test case by reimplementing the function ResolveAggregateFunctions
+  ignore("window function: Pushing aggregate Expressions in Sort to Aggregate") {
+    val data = Seq(
+      WindowData(1, "d", 10),
+      WindowData(2, "a", 6),
+      WindowData(3, "b", 7),
+      WindowData(4, "b", 8),
+      WindowData(5, "c", 9),
+      WindowData(6, "c", 11)
+    )
+    sparkContext.parallelize(data).toDF().registerTempTable("windowData")
+
+    checkAnswer(
+      sql(
+        """
+          |select area, sum(product) over () as c from windowData
+          |where product > 3 group by area, product
+          |having avg(month) > 0 order by avg(month), product
+        """.stripMargin),
+      Seq(
+        ("a", 51),
+        ("b", 51),
+        ("b", 51),
+        ("c", 51),
+        ("c", 51),
+        ("d", 51)
+      ).map(i => Row(i._1, i._2)))
+  }
+
   test("window function: multiple window expressions in a single expression") {
     val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y")
     nums.registerTempTable("nums")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org