Posted to commits@spark.apache.org by db...@apache.org on 2019/03/12 21:46:00 UTC

[spark] branch master updated: [SPARK-27123][SQL] Improve CollapseProject to handle projects cross limit/repartition/sample

This is an automated email from the ASF dual-hosted git repository.

dbtsai pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 78314af  [SPARK-27123][SQL] Improve CollapseProject to handle projects cross limit/repartition/sample
78314af is described below

commit 78314af580b38f773a148c6f035d2ddd79896b4c
Author: Dongjoon Hyun <dh...@apple.com>
AuthorDate: Tue Mar 12 21:45:40 2019 +0000

    [SPARK-27123][SQL] Improve CollapseProject to handle projects cross limit/repartition/sample
    
    ## What changes were proposed in this pull request?
    
    The `CollapseProject` optimizer rule simplifies some plans by merging adjacent projects and performing alias substitutions.
    ```scala
    scala> sql("SELECT b c FROM (SELECT a b FROM t)").explain
    == Physical Plan ==
    *(1) Project [a#5 AS c#1]
    +- Scan hive default.t [a#5], HiveTableRelation `default`.`t`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [a#5]
    ```
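
    Conceptually, the merge substitutes the inner aliases into the outer project list: `Project(b AS c, Project(a AS b, t))` becomes `Project(a AS c, t)`. A minimal sketch of the same rewrite via the Catalyst DSL (illustration only, not part of this patch; assumes `spark-catalyst` on the classpath):
    ```scala
    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.optimizer.CollapseProject
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    // SELECT b c FROM (SELECT a b FROM t): two adjacent projects.
    val t = LocalRelation('a.int)
    val before = t.select('a as 'b).select('b as 'c).analyze

    // Applying the rule directly merges them into a single Project [a AS c].
    val after = CollapseProject(before)
    ```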
    
    We can do the same for more complex cases like the following. This PR aims to handle adjacent projects across limit/repartition/sample. Here, repartition means `Repartition`, not `RepartitionByExpression`.
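
    A quick, hypothetical spark-shell check (not part of this patch) shows that distinction; `spark` is the usual `SparkSession`:
    ```scala
    // Repartition: round-robin, no partitioning expressions -- handled by this rule.
    scala> spark.range(10).repartition(1).queryExecution.optimizedPlan

    // RepartitionByExpression: partitions by a column -- deliberately not handled.
    scala> spark.range(10).repartition($"id").queryExecution.optimizedPlan
    ```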
    
    **BEFORE**
    ```scala
    scala> sql("SELECT b c FROM (SELECT /*+ REPARTITION(1) */ a b FROM t)").explain
    == Physical Plan ==
    *(2) Project [b#0 AS c#1]
    +- Exchange RoundRobinPartitioning(1)
       +- *(1) Project [a#5 AS b#0]
          +- Scan hive default.t [a#5], HiveTableRelation `default`.`t`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [a#5]
    ```
    
    **AFTER**
    ```scala
    scala> sql("SELECT b c FROM (SELECT /*+ REPARTITION(1) */ a b FROM t)").explain
    == Physical Plan ==
    Exchange RoundRobinPartitioning(1)
    +- *(1) Project [a#11 AS c#7]
       +- Scan hive default.t [a#11], HiveTableRelation `default`.`t`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [a#11]
    ```
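
    Note the guards retained from the existing rule (see the diff below): the limit cases bail out via `haveCommonNonDeterministicOutput` when the outer project references a non-deterministic alias from the inner one, and the repartition/sample cases use the stricter `isRenaming` check, which only permits pure attribute renames. A hypothetical query the rule must leave alone:
    ```scala
    // `r` aliases the non-deterministic rand(); substituting it into the
    // lower project could change how often rand() is evaluated, so the
    // two projects around the limit stay separate.
    scala> sql("SELECT r + 1 FROM (SELECT rand() r FROM t LIMIT 10)").explain
    ```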
    
    ## How was this patch tested?
    
    Passes Jenkins with the newly added and updated test cases.
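
    For a local check, something like `build/sbt "catalyst/testOnly *CollapseProjectSuite *ColumnPruningSuite"` (command shape assumed from the usual Spark developer workflow) runs the touched suites.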
    
    Closes #24049 from dongjoon-hyun/SPARK-27123.
    
    Authored-by: Dongjoon Hyun <dh...@apple.com>
    Signed-off-by: DB Tsai <d_...@apple.com>
---
 .../spark/sql/catalyst/optimizer/Optimizer.scala   | 26 +++++++++++++++++
 .../catalyst/optimizer/CollapseProjectSuite.scala  | 34 +++++++++++++++++++++-
 .../catalyst/optimizer/ColumnPruningSuite.scala    |  2 +-
 3 files changed, 60 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 97a53f2..1b7ff02 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -699,6 +699,24 @@ object CollapseProject extends Rule[LogicalPlan] {
         agg.copy(aggregateExpressions = buildCleanedProjectList(
           p.projectList, agg.aggregateExpressions))
       }
+    case p1 @ Project(_, g @ GlobalLimit(_, l @ LocalLimit(_, p2: Project))) =>
+      if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) {
+        p1
+      } else {
+        val newProjectList = buildCleanedProjectList(p1.projectList, p2.projectList)
+        g.copy(child = l.copy(child = p2.copy(projectList = newProjectList)))
+      }
+    case p1 @ Project(_, l @ LocalLimit(_, p2: Project)) =>
+      if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) {
+        p1
+      } else {
+        val newProjectList = buildCleanedProjectList(p1.projectList, p2.projectList)
+        l.copy(child = p2.copy(projectList = newProjectList))
+      }
+    case Project(l1, r @ Repartition(_, _, p @ Project(l2, _))) if isRenaming(l1, l2) =>
+      r.copy(child = p.copy(projectList = buildCleanedProjectList(l1, p.projectList)))
+    case Project(l1, s @ Sample(_, _, _, _, p2 @ Project(l2, _))) if isRenaming(l1, l2) =>
+      s.copy(child = p2.copy(projectList = buildCleanedProjectList(l1, p2.projectList)))
   }
 
   private def collectAliases(projectList: Seq[NamedExpression]): AttributeMap[Alias] = {
@@ -739,6 +757,14 @@ object CollapseProject extends Rule[LogicalPlan] {
       CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
     }
   }
+
+  private def isRenaming(list1: Seq[NamedExpression], list2: Seq[NamedExpression]): Boolean = {
+    list1.length == list2.length && list1.zip(list2).forall {
+      case (e1, e2) if e1.semanticEquals(e2) => true
+      case (Alias(a: Attribute, _), b) if a.metadata == Metadata.empty && a.name == b.name => true
+      case _ => false
+    }
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
index e7a5bce..42bcd13 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.{Alias, Rand}
 import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.types.MetadataBuilder
 
@@ -138,4 +138,36 @@ class CollapseProjectSuite extends PlanTest {
     assert(projects.size === 1)
     assert(hasMetadata(optimized))
   }
+
+  test("collapse redundant alias through limit") {
+    val relation = LocalRelation('a.int, 'b.int)
+    val query = relation.select('a as 'b).limit(1).select('b as 'c).analyze
+    val optimized = Optimize.execute(query)
+    val expected = relation.select('a as 'c).limit(1).analyze
+    comparePlans(optimized, expected)
+  }
+
+  test("collapse redundant alias through local limit") {
+    val relation = LocalRelation('a.int, 'b.int)
+    val query = LocalLimit(1, relation.select('a as 'b)).select('b as 'c).analyze
+    val optimized = Optimize.execute(query)
+    val expected = LocalLimit(1, relation.select('a as 'c)).analyze
+    comparePlans(optimized, expected)
+  }
+
+  test("collapse redundant alias through repartition") {
+    val relation = LocalRelation('a.int, 'b.int)
+    val query = relation.select('a as 'b).repartition(1).select('b as 'c).analyze
+    val optimized = Optimize.execute(query)
+    val expected = relation.select('a as 'c).repartition(1).analyze
+    comparePlans(optimized, expected)
+  }
+
+  test("collapse redundant alias through sample") {
+    val relation = LocalRelation('a.int, 'b.int)
+    val query = Sample(0.0, 0.6, false, 11L, relation.select('a as 'b)).select('b as 'c).analyze
+    val optimized = Optimize.execute(query)
+    val expected = Sample(0.0, 0.6, false, 11L, relation.select('a as 'c)).analyze
+    comparePlans(optimized, expected)
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 41bc4d8..b738f30 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -388,7 +388,7 @@ class ColumnPruningSuite extends PlanTest {
 
     val query2 = Sample(0.0, 0.6, false, 11L, x).select('a as 'aa)
     val optimized2 = Optimize.execute(query2.analyze)
-    val expected2 = Sample(0.0, 0.6, false, 11L, x.select('a)).select('a as 'aa)
+    val expected2 = Sample(0.0, 0.6, false, 11L, x.select('a as 'aa))
     comparePlans(optimized2, expected2.analyze)
   }
 

