Posted to commits@spark.apache.org by da...@apache.org on 2016/02/24 03:19:32 UTC

spark git commit: [SPARK-13376] [SQL] improve column pruning

Repository: spark
Updated Branches:
  refs/heads/master 230bbeaa6 -> e9533b419


[SPARK-13376] [SQL] improve column pruning

## What changes were proposed in this pull request?

This PR mostly rewrites the ColumnPruning rule to support most of the SQL logical plans (except those for Dataset).
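
For illustration, a minimal sketch of the new behavior, mirroring the "Column pruning on Filter" test this patch adds to ColumnPruningSuite (the wrapping object is only for the sketch): a column needed only below a Filter is pruned by inserting a Project underneath it.

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.optimizer.ColumnPruning
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, Project}

object ColumnPruningExample extends App {
  val input = LocalRelation('a.int, 'b.string, 'c.double)
  // Only 'a is projected above, and only 'c is needed by the Filter,
  // so 'b can be dropped below the Filter.
  val query = Project('a :: Nil, Filter('c > Literal(0.0), input)).analyze
  // One application of the rule yields:
  //   Project [a]
  //   +- Filter (c > 0.0)
  //      +- Project [a, c]        <- 'b is pruned here
  //         +- LocalRelation [a, b, c]
  println(ColumnPruning(query).treeString)
}
```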

## How was this patch tested?

This is tested by unit tests, and also manually with TPCDS Q78, which prunes all unused columns successfully, improving performance by 78% (from 22s to 12s).
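
As a concrete example of what the new unit tests check, the Union case now prunes every child positionally, keeping a child's column only if the corresponding column of the first child is still referenced. A minimal sketch mirroring the "Column pruning on Union" test below (again, the wrapping object is only for the sketch):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.optimizer.ColumnPruning
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project, Union}

object UnionPruningExample extends App {
  val input1 = LocalRelation('a.int, 'b.string, 'c.double)
  val input2 = LocalRelation('c.int, 'd.string, 'e.double)
  // Only the second output column ('b) is referenced, so only the second
  // column of each child ('b and 'd respectively) is kept:
  //   Project [b]
  //   +- Union
  //      :- Project [b] over LocalRelation [a, b, c]
  //      +- Project [d] over LocalRelation [c, d, e]
  val query = Project('b :: Nil, Union(input1 :: input2 :: Nil)).analyze
  println(ColumnPruning(query).treeString)
}
```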

Author: Davies Liu <da...@databricks.com>

Closes #11256 from davies/fix_column_pruning.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e9533b41
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e9533b41
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e9533b41

Branch: refs/heads/master
Commit: e9533b419e3a87589313350310890ce0caf73dbb
Parents: 230bbea
Author: Davies Liu <da...@databricks.com>
Authored: Tue Feb 23 18:19:22 2016 -0800
Committer: Davies Liu <da...@gmail.com>
Committed: Tue Feb 23 18:19:22 2016 -0800

----------------------------------------------------------------------
 .../sql/catalyst/optimizer/Optimizer.scala      | 128 +++++++++----------
 .../catalyst/optimizer/ColumnPruningSuite.scala | 128 ++++++++++++++++++-
 .../optimizer/FilterPushdownSuite.scala         |  80 ------------
 .../columnar/InMemoryColumnarTableScan.scala    |   7 +-
 4 files changed, 187 insertions(+), 156 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e9533b41/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 1f05f20..2b80497 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -313,97 +313,85 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
  */
 object ColumnPruning extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case a @ Aggregate(_, _, e @ Expand(projects, output, child))
-      if (e.outputSet -- a.references).nonEmpty =>
-      val newOutput = output.filter(a.references.contains(_))
-      val newProjects = projects.map { proj =>
-        proj.zip(output).filter { case (e, a) =>
+    // Prunes the unused columns from project list of Project/Aggregate/Window/Expand
+    case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty =>
+      p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains)))
+    case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty =>
+      p.copy(
+        child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains)))
+    case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty =>
+      p.copy(child = w.copy(
+        projectList = w.projectList.filter(p.references.contains),
+        windowExpressions = w.windowExpressions.filter(p.references.contains)))
+    case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty =>
+      val newOutput = e.output.filter(a.references.contains(_))
+      val newProjects = e.projections.map { proj =>
+        proj.zip(e.output).filter { case (e, a) =>
           newOutput.contains(a)
         }.unzip._1
       }
-      a.copy(child = Expand(newProjects, newOutput, child))
+      a.copy(child = Expand(newProjects, newOutput, grandChild))
+    // TODO: support logical plans for Dataset
 
-    case a @ Aggregate(_, _, e @ Expand(_, _, child))
-      if (child.outputSet -- e.references -- a.references).nonEmpty =>
-      a.copy(child = e.copy(child = prunedChild(child, e.references ++ a.references)))
-
-    // Eliminate attributes that are not needed to calculate the specified aggregates.
+    // Prunes the unused columns from child of Aggregate/Window/Expand/Generate
     case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
-      a.copy(child = Project(a.references.toSeq, child))
-
-    // Eliminate attributes that are not needed to calculate the Generate.
+      a.copy(child = prunedChild(child, a.references))
+    case w @ Window(_, _, _, _, child) if (child.outputSet -- w.references).nonEmpty =>
+      w.copy(child = prunedChild(child, w.references))
+    case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty =>
+      e.copy(child = prunedChild(child, e.references))
     case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty =>
-      g.copy(child = Project(g.references.toSeq, g.child))
+      g.copy(child = prunedChild(g.child, g.references))
 
+    // Turn off `join` for Generate if no column from its child is used
     case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) =>
       p.copy(child = g.copy(join = false))
 
-    case p @ Project(projectList, g: Generate) if g.join =>
-      val neededChildOutput = p.references -- g.generatorOutput ++ g.references
-      if (neededChildOutput == g.child.outputSet) {
-        p
+    // Eliminate unneeded attributes from right side of a LeftSemiJoin.
+    case j @ Join(left, right, LeftSemi, condition) =>
+      j.copy(right = prunedChild(right, j.references))
+
+    // All the columns will be used in the comparison, so we can't prune them
+    case p @ Project(_, _: SetOperation) => p
+    case p @ Project(_, _: Distinct) => p
+    // Eliminate unneeded attributes from children of Union.
+    case p @ Project(_, u: Union) =>
+      if ((u.outputSet -- p.references).nonEmpty) {
+        val firstChild = u.children.head
+        val newOutput = prunedChild(firstChild, p.references).output
+        // pruning the columns of all children based on the pruned first child.
+        val newChildren = u.children.map { p =>
+          val selected = p.output.zipWithIndex.filter { case (a, i) =>
+            newOutput.contains(firstChild.output(i))
+          }.map(_._1)
+          Project(selected, p)
+        }
+        p.copy(child = u.withNewChildren(newChildren))
       } else {
-        Project(projectList, g.copy(child = Project(neededChildOutput.toSeq, g.child)))
+        p
       }
 
-    case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child))
-        if (a.outputSet -- p.references).nonEmpty =>
-      Project(
-        projectList,
-        Aggregate(
-          groupingExpressions,
-          aggregateExpressions.filter(e => p.references.contains(e)),
-          child))
-
-    // Eliminate unneeded attributes from either side of a Join.
-    case Project(projectList, Join(left, right, joinType, condition)) =>
-      // Collect the list of all references required either above or to evaluate the condition.
-      val allReferences: AttributeSet =
-        AttributeSet(
-          projectList.flatMap(_.references.iterator)) ++
-          condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
-
-      /** Applies a projection only when the child is producing unnecessary attributes */
-      def pruneJoinChild(c: LogicalPlan): LogicalPlan = prunedChild(c, allReferences)
+    // Can't prune the columns on LeafNode
+    case p @ Project(_, l: LeafNode) => p
 
-      Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition))
-
-    // Eliminate unneeded attributes from right side of a LeftSemiJoin.
-    case Join(left, right, LeftSemi, condition) =>
-      // Collect the list of all references required to evaluate the condition.
-      val allReferences: AttributeSet =
-        condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
-
-      Join(left, prunedChild(right, allReferences), LeftSemi, condition)
-
-    // Push down project through limit, so that we may have chance to push it further.
-    case Project(projectList, Limit(exp, child)) =>
-      Limit(exp, Project(projectList, child))
-
-    // Push down project if possible when the child is sort.
-    case p @ Project(projectList, s @ Sort(_, _, grandChild)) =>
-      if (s.references.subsetOf(p.outputSet)) {
-        s.copy(child = Project(projectList, grandChild))
+    // Eliminate no-op Projects
+    case p @ Project(projectList, child) if child.output == p.output => child
+
+    // For all other logical plans that inherit the output from their children
+    case p @ Project(_, child) =>
+      val required = child.references ++ p.references
+      if ((child.inputSet -- required).nonEmpty) {
+        val newChildren = child.children.map(c => prunedChild(c, required))
+        p.copy(child = child.withNewChildren(newChildren))
       } else {
-        val neededReferences = s.references ++ p.references
-        if (neededReferences == grandChild.outputSet) {
-          // No column we can prune, return the original plan.
-          p
-        } else {
-          // Do not use neededReferences.toSeq directly, should respect grandChild's output order.
-          val newProjectList = grandChild.output.filter(neededReferences.contains)
-          p.copy(child = s.copy(child = Project(newProjectList, grandChild)))
-        }
+        p
       }
-
-    // Eliminate no-op Projects
-    case Project(projectList, child) if child.output == projectList => child
   }
 
   /** Applies a projection only when the child is producing unnecessary attributes */
   private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
     if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
-      Project(allReferences.filter(c.outputSet.contains).toSeq, c)
+      Project(c.output.filter(allReferences.contains), c)
     } else {
       c
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/e9533b41/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index c890fff..715d01a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Explode, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Explode, Literal, SortOrder}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -119,11 +120,134 @@ class ColumnPruningSuite extends PlanTest {
             Seq('c, Literal.create(null, StringType), 1),
             Seq('c, 'a, 2)),
           Seq('c, 'aa.int, 'gid.int),
-          Project(Seq('c, 'a),
+          Project(Seq('a, 'c),
             input))).analyze
 
     comparePlans(optimized, expected)
   }
 
+  test("Column pruning on Filter") {
+    val input = LocalRelation('a.int, 'b.string, 'c.double)
+    val query = Project('a :: Nil, Filter('c > Literal(0.0), input)).analyze
+    val expected =
+      Project('a :: Nil,
+        Filter('c > Literal(0.0),
+          Project(Seq('a, 'c), input))).analyze
+    comparePlans(Optimize.execute(query), expected)
+  }
+
+  test("Column pruning on except/intersect/distinct") {
+    val input = LocalRelation('a.int, 'b.string, 'c.double)
+    val query = Project('a :: Nil, Except(input, input)).analyze
+    comparePlans(Optimize.execute(query), query)
+
+    val query2 = Project('a :: Nil, Intersect(input, input)).analyze
+    comparePlans(Optimize.execute(query2), query2)
+    val query3 = Project('a :: Nil, Distinct(input)).analyze
+    comparePlans(Optimize.execute(query3), query3)
+  }
+
+  test("Column pruning on Project") {
+    val input = LocalRelation('a.int, 'b.string, 'c.double)
+    val query = Project('a :: Nil, Project(Seq('a, 'b), input)).analyze
+    val expected = Project(Seq('a), input).analyze
+    comparePlans(Optimize.execute(query), expected)
+  }
+
+  test("column pruning for group") {
+    val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+    val originalQuery =
+      testRelation
+        .groupBy('a)('a, count('b))
+        .select('a)
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer =
+      testRelation
+        .select('a)
+        .groupBy('a)('a).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("column pruning for group with alias") {
+    val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+    val originalQuery =
+      testRelation
+        .groupBy('a)('a as 'c, count('b))
+        .select('c)
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer =
+      testRelation
+        .select('a)
+        .groupBy('a)('a as 'c).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("column pruning for Project(ne, Limit)") {
+    val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+    val originalQuery =
+      testRelation
+        .select('a, 'b)
+        .limit(2)
+        .select('a)
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer =
+      testRelation
+        .select('a)
+        .limit(2).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("push down project past sort") {
+    val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+    val x = testRelation.subquery('x)
+
+    // push down valid
+    val originalQuery = {
+      x.select('a, 'b)
+        .sortBy(SortOrder('a, Ascending))
+        .select('a)
+    }
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer =
+      x.select('a)
+        .sortBy(SortOrder('a, Ascending)).analyze
+
+    comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))
+
+    // push down invalid
+    val originalQuery1 = {
+      x.select('a, 'b)
+        .sortBy(SortOrder('a, Ascending))
+        .select('b)
+    }
+
+    val optimized1 = Optimize.execute(originalQuery1.analyze)
+    val correctAnswer1 =
+      x.select('a, 'b)
+        .sortBy(SortOrder('a, Ascending))
+        .select('b).analyze
+
+    comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1))
+  }
+
+  test("Column pruning on Union") {
+    val input1 = LocalRelation('a.int, 'b.string, 'c.double)
+    val input2 = LocalRelation('c.int, 'd.string, 'e.double)
+    val query = Project('b :: Nil,
+      Union(input1 :: input2 :: Nil)).analyze
+    val expected = Project('b :: Nil,
+      Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil)).analyze
+    comparePlans(Optimize.execute(query), expected)
+  }
+
   // todo: add more tests for column pruning
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e9533b41/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 70b34cb..7d60862 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -41,7 +41,6 @@ class FilterPushdownSuite extends PlanTest {
         PushPredicateThroughJoin,
         PushPredicateThroughGenerate,
         PushPredicateThroughAggregate,
-        ColumnPruning,
         CollapseProject) :: Nil
   }
 
@@ -65,52 +64,6 @@ class FilterPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("column pruning for group") {
-    val originalQuery =
-      testRelation
-        .groupBy('a)('a, count('b))
-        .select('a)
-
-    val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer =
-      testRelation
-        .select('a)
-        .groupBy('a)('a).analyze
-
-    comparePlans(optimized, correctAnswer)
-  }
-
-  test("column pruning for group with alias") {
-    val originalQuery =
-      testRelation
-        .groupBy('a)('a as 'c, count('b))
-        .select('c)
-
-    val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer =
-      testRelation
-        .select('a)
-        .groupBy('a)('a as 'c).analyze
-
-    comparePlans(optimized, correctAnswer)
-  }
-
-  test("column pruning for Project(ne, Limit)") {
-    val originalQuery =
-      testRelation
-        .select('a, 'b)
-        .limit(2)
-        .select('a)
-
-    val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer =
-      testRelation
-        .select('a)
-        .limit(2).analyze
-
-    comparePlans(optimized, correctAnswer)
-  }
-
   // After this line is unimplemented.
   test("simple push down") {
     val originalQuery =
@@ -604,39 +557,6 @@ class FilterPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery)
   }
 
-  test("push down project past sort") {
-    val x = testRelation.subquery('x)
-
-    // push down valid
-    val originalQuery = {
-      x.select('a, 'b)
-       .sortBy(SortOrder('a, Ascending))
-       .select('a)
-    }
-
-    val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer =
-      x.select('a)
-       .sortBy(SortOrder('a, Ascending)).analyze
-
-    comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))
-
-    // push down invalid
-    val originalQuery1 = {
-      x.select('a, 'b)
-       .sortBy(SortOrder('a, Ascending))
-       .select('b)
-    }
-
-    val optimized1 = Optimize.execute(originalQuery1.analyze)
-    val correctAnswer1 =
-      x.select('a, 'b)
-       .sortBy(SortOrder('a, Ascending))
-       .select('b).analyze
-
-    comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1))
-  }
-
   test("push project and filter down into sample") {
     val x = testRelation.subquery('x)
     val originalQuery =

http://git-wip-us.apache.org/repos/asf/spark/blob/e9533b41/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
index 4858140..22d4278 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
@@ -26,7 +26,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.{LeafNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -63,7 +64,7 @@ private[sql] case class InMemoryRelation(
     @transient private[sql] var _cachedColumnBuffers: RDD[CachedBatch] = null,
     @transient private[sql] var _statistics: Statistics = null,
     private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null)
-  extends LogicalPlan with MultiInstanceRelation {
+  extends logical.LeafNode with MultiInstanceRelation {
 
   override def producedAttributes: AttributeSet = outputSet
 
@@ -184,8 +185,6 @@ private[sql] case class InMemoryRelation(
       _cachedColumnBuffers, statisticsToBePropagated, batchStats)
   }
 
-  override def children: Seq[LogicalPlan] = Seq.empty
-
   override def newInstance(): this.type = {
     new InMemoryRelation(
       output.map(_.newInstance()),

