You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2016/02/25 09:13:10 UTC

spark git commit: [SPARK-13376] [SPARK-13476] [SQL] improve column pruning

Repository: spark
Updated Branches:
  refs/heads/master 264533b55 -> 07f92ef1f


[SPARK-13376] [SPARK-13476] [SQL] improve column pruning

## What changes were proposed in this pull request?

This PR mostly rewrites the ColumnPruning rule to support most SQL logical plans (except those for Dataset).

This PR also fixes a bug in Generate, which should always output UnsafeRow, and adds a regression test for it.

## How was this patch tested?

This was tested by unit tests, and also manually tested with TPC-DS Q78: all unused columns were pruned successfully, and the runtime dropped from 22s to 12s (about 1.8x faster).

Author: Davies Liu <da...@databricks.com>

Closes #11354 from davies/fix_column_pruning.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/07f92ef1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/07f92ef1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/07f92ef1

Branch: refs/heads/master
Commit: 07f92ef1fa090821bef9c60689bf41909d781ee7
Parents: 264533b
Author: Davies Liu <da...@databricks.com>
Authored: Thu Feb 25 00:13:07 2016 -0800
Committer: Davies Liu <da...@gmail.com>
Committed: Thu Feb 25 00:13:07 2016 -0800

----------------------------------------------------------------------
 .../sql/catalyst/optimizer/Optimizer.scala      | 128 +++++++++----------
 .../catalyst/optimizer/ColumnPruningSuite.scala | 128 ++++++++++++++++++-
 .../optimizer/FilterPushdownSuite.scala         |  80 ------------
 .../optimizer/JoinOptimizationSuite.scala       |   2 +-
 .../apache/spark/sql/execution/Generate.scala   |  28 ++--
 .../columnar/InMemoryColumnarTableScan.scala    |   7 +-
 .../org/apache/spark/sql/DataFrameSuite.scala   |   8 ++
 7 files changed, 215 insertions(+), 166 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/07f92ef1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 1f05f20..2b80497 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -313,97 +313,85 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
  */
 object ColumnPruning extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case a @ Aggregate(_, _, e @ Expand(projects, output, child))
-      if (e.outputSet -- a.references).nonEmpty =>
-      val newOutput = output.filter(a.references.contains(_))
-      val newProjects = projects.map { proj =>
-        proj.zip(output).filter { case (e, a) =>
+    // Prunes the unused columns from project list of Project/Aggregate/Window/Expand
+    case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty =>
+      p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains)))
+    case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty =>
+      p.copy(
+        child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains)))
+    case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty =>
+      p.copy(child = w.copy(
+        projectList = w.projectList.filter(p.references.contains),
+        windowExpressions = w.windowExpressions.filter(p.references.contains)))
+    case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty =>
+      val newOutput = e.output.filter(a.references.contains(_))
+      val newProjects = e.projections.map { proj =>
+        proj.zip(e.output).filter { case (e, a) =>
           newOutput.contains(a)
         }.unzip._1
       }
-      a.copy(child = Expand(newProjects, newOutput, child))
+      a.copy(child = Expand(newProjects, newOutput, grandChild))
+    // TODO: support some logical plan for Dataset
 
-    case a @ Aggregate(_, _, e @ Expand(_, _, child))
-      if (child.outputSet -- e.references -- a.references).nonEmpty =>
-      a.copy(child = e.copy(child = prunedChild(child, e.references ++ a.references)))
-
-    // Eliminate attributes that are not needed to calculate the specified aggregates.
+    // Prunes the unused columns from child of Aggregate/Window/Expand/Generate
     case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
-      a.copy(child = Project(a.references.toSeq, child))
-
-    // Eliminate attributes that are not needed to calculate the Generate.
+      a.copy(child = prunedChild(child, a.references))
+    case w @ Window(_, _, _, _, child) if (child.outputSet -- w.references).nonEmpty =>
+      w.copy(child = prunedChild(child, w.references))
+    case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty =>
+      e.copy(child = prunedChild(child, e.references))
     case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty =>
-      g.copy(child = Project(g.references.toSeq, g.child))
+      g.copy(child = prunedChild(g.child, g.references))
 
+    // Turn off `join` for Generate if no column from it's child is used
     case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) =>
       p.copy(child = g.copy(join = false))
 
-    case p @ Project(projectList, g: Generate) if g.join =>
-      val neededChildOutput = p.references -- g.generatorOutput ++ g.references
-      if (neededChildOutput == g.child.outputSet) {
-        p
+    // Eliminate unneeded attributes from right side of a LeftSemiJoin.
+    case j @ Join(left, right, LeftSemi, condition) =>
+      j.copy(right = prunedChild(right, j.references))
+
+    // all the columns will be used to compare, so we can't prune them
+    case p @ Project(_, _: SetOperation) => p
+    case p @ Project(_, _: Distinct) => p
+    // Eliminate unneeded attributes from children of Union.
+    case p @ Project(_, u: Union) =>
+      if ((u.outputSet -- p.references).nonEmpty) {
+        val firstChild = u.children.head
+        val newOutput = prunedChild(firstChild, p.references).output
+        // pruning the columns of all children based on the pruned first child.
+        val newChildren = u.children.map { p =>
+          val selected = p.output.zipWithIndex.filter { case (a, i) =>
+            newOutput.contains(firstChild.output(i))
+          }.map(_._1)
+          Project(selected, p)
+        }
+        p.copy(child = u.withNewChildren(newChildren))
       } else {
-        Project(projectList, g.copy(child = Project(neededChildOutput.toSeq, g.child)))
+        p
       }
 
-    case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child))
-        if (a.outputSet -- p.references).nonEmpty =>
-      Project(
-        projectList,
-        Aggregate(
-          groupingExpressions,
-          aggregateExpressions.filter(e => p.references.contains(e)),
-          child))
-
-    // Eliminate unneeded attributes from either side of a Join.
-    case Project(projectList, Join(left, right, joinType, condition)) =>
-      // Collect the list of all references required either above or to evaluate the condition.
-      val allReferences: AttributeSet =
-        AttributeSet(
-          projectList.flatMap(_.references.iterator)) ++
-          condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
-
-      /** Applies a projection only when the child is producing unnecessary attributes */
-      def pruneJoinChild(c: LogicalPlan): LogicalPlan = prunedChild(c, allReferences)
+    // Can't prune the columns on LeafNode
+    case p @ Project(_, l: LeafNode) => p
 
-      Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition))
-
-    // Eliminate unneeded attributes from right side of a LeftSemiJoin.
-    case Join(left, right, LeftSemi, condition) =>
-      // Collect the list of all references required to evaluate the condition.
-      val allReferences: AttributeSet =
-        condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
-
-      Join(left, prunedChild(right, allReferences), LeftSemi, condition)
-
-    // Push down project through limit, so that we may have chance to push it further.
-    case Project(projectList, Limit(exp, child)) =>
-      Limit(exp, Project(projectList, child))
-
-    // Push down project if possible when the child is sort.
-    case p @ Project(projectList, s @ Sort(_, _, grandChild)) =>
-      if (s.references.subsetOf(p.outputSet)) {
-        s.copy(child = Project(projectList, grandChild))
+    // Eliminate no-op Projects
+    case p @ Project(projectList, child) if child.output == p.output => child
+
+    // for all other logical plans that inherits the output from it's children
+    case p @ Project(_, child) =>
+      val required = child.references ++ p.references
+      if ((child.inputSet -- required).nonEmpty) {
+        val newChildren = child.children.map(c => prunedChild(c, required))
+        p.copy(child = child.withNewChildren(newChildren))
       } else {
-        val neededReferences = s.references ++ p.references
-        if (neededReferences == grandChild.outputSet) {
-          // No column we can prune, return the original plan.
-          p
-        } else {
-          // Do not use neededReferences.toSeq directly, should respect grandChild's output order.
-          val newProjectList = grandChild.output.filter(neededReferences.contains)
-          p.copy(child = s.copy(child = Project(newProjectList, grandChild)))
-        }
+        p
       }
-
-    // Eliminate no-op Projects
-    case Project(projectList, child) if child.output == projectList => child
   }
 
   /** Applies a projection only when the child is producing unnecessary attributes */
   private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
     if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
-      Project(allReferences.filter(c.outputSet.contains).toSeq, c)
+      Project(c.output.filter(allReferences.contains), c)
     } else {
       c
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/07f92ef1/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index c890fff..715d01a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Explode, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Explode, Literal, SortOrder}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -119,11 +120,134 @@ class ColumnPruningSuite extends PlanTest {
             Seq('c, Literal.create(null, StringType), 1),
             Seq('c, 'a, 2)),
           Seq('c, 'aa.int, 'gid.int),
-          Project(Seq('c, 'a),
+          Project(Seq('a, 'c),
             input))).analyze
 
     comparePlans(optimized, expected)
   }
 
+  test("Column pruning on Filter") {
+    val input = LocalRelation('a.int, 'b.string, 'c.double)
+    val query = Project('a :: Nil, Filter('c > Literal(0.0), input)).analyze
+    val expected =
+      Project('a :: Nil,
+        Filter('c > Literal(0.0),
+          Project(Seq('a, 'c), input))).analyze
+    comparePlans(Optimize.execute(query), expected)
+  }
+
+  test("Column pruning on except/intersect/distinct") {
+    val input = LocalRelation('a.int, 'b.string, 'c.double)
+    val query = Project('a :: Nil, Except(input, input)).analyze
+    comparePlans(Optimize.execute(query), query)
+
+    val query2 = Project('a :: Nil, Intersect(input, input)).analyze
+    comparePlans(Optimize.execute(query2), query2)
+    val query3 = Project('a :: Nil, Distinct(input)).analyze
+    comparePlans(Optimize.execute(query3), query3)
+  }
+
+  test("Column pruning on Project") {
+    val input = LocalRelation('a.int, 'b.string, 'c.double)
+    val query = Project('a :: Nil, Project(Seq('a, 'b), input)).analyze
+    val expected = Project(Seq('a), input).analyze
+    comparePlans(Optimize.execute(query), expected)
+  }
+
+  test("column pruning for group") {
+    val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+    val originalQuery =
+      testRelation
+        .groupBy('a)('a, count('b))
+        .select('a)
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer =
+      testRelation
+        .select('a)
+        .groupBy('a)('a).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("column pruning for group with alias") {
+    val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+    val originalQuery =
+      testRelation
+        .groupBy('a)('a as 'c, count('b))
+        .select('c)
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer =
+      testRelation
+        .select('a)
+        .groupBy('a)('a as 'c).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("column pruning for Project(ne, Limit)") {
+    val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+    val originalQuery =
+      testRelation
+        .select('a, 'b)
+        .limit(2)
+        .select('a)
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer =
+      testRelation
+        .select('a)
+        .limit(2).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("push down project past sort") {
+    val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+    val x = testRelation.subquery('x)
+
+    // push down valid
+    val originalQuery = {
+      x.select('a, 'b)
+        .sortBy(SortOrder('a, Ascending))
+        .select('a)
+    }
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer =
+      x.select('a)
+        .sortBy(SortOrder('a, Ascending)).analyze
+
+    comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))
+
+    // push down invalid
+    val originalQuery1 = {
+      x.select('a, 'b)
+        .sortBy(SortOrder('a, Ascending))
+        .select('b)
+    }
+
+    val optimized1 = Optimize.execute(originalQuery1.analyze)
+    val correctAnswer1 =
+      x.select('a, 'b)
+        .sortBy(SortOrder('a, Ascending))
+        .select('b).analyze
+
+    comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1))
+  }
+
+  test("Column pruning on Union") {
+    val input1 = LocalRelation('a.int, 'b.string, 'c.double)
+    val input2 = LocalRelation('c.int, 'd.string, 'e.double)
+    val query = Project('b :: Nil,
+      Union(input1 :: input2 :: Nil)).analyze
+    val expected = Project('b :: Nil,
+      Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil)).analyze
+    comparePlans(Optimize.execute(query), expected)
+  }
+
   // todo: add more tests for column pruning
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/07f92ef1/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 70b34cb..7d60862 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -41,7 +41,6 @@ class FilterPushdownSuite extends PlanTest {
         PushPredicateThroughJoin,
         PushPredicateThroughGenerate,
         PushPredicateThroughAggregate,
-        ColumnPruning,
         CollapseProject) :: Nil
   }
 
@@ -65,52 +64,6 @@ class FilterPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("column pruning for group") {
-    val originalQuery =
-      testRelation
-        .groupBy('a)('a, count('b))
-        .select('a)
-
-    val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer =
-      testRelation
-        .select('a)
-        .groupBy('a)('a).analyze
-
-    comparePlans(optimized, correctAnswer)
-  }
-
-  test("column pruning for group with alias") {
-    val originalQuery =
-      testRelation
-        .groupBy('a)('a as 'c, count('b))
-        .select('c)
-
-    val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer =
-      testRelation
-        .select('a)
-        .groupBy('a)('a as 'c).analyze
-
-    comparePlans(optimized, correctAnswer)
-  }
-
-  test("column pruning for Project(ne, Limit)") {
-    val originalQuery =
-      testRelation
-        .select('a, 'b)
-        .limit(2)
-        .select('a)
-
-    val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer =
-      testRelation
-        .select('a)
-        .limit(2).analyze
-
-    comparePlans(optimized, correctAnswer)
-  }
-
   // After this line is unimplemented.
   test("simple push down") {
     val originalQuery =
@@ -604,39 +557,6 @@ class FilterPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery)
   }
 
-  test("push down project past sort") {
-    val x = testRelation.subquery('x)
-
-    // push down valid
-    val originalQuery = {
-      x.select('a, 'b)
-       .sortBy(SortOrder('a, Ascending))
-       .select('a)
-    }
-
-    val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer =
-      x.select('a)
-       .sortBy(SortOrder('a, Ascending)).analyze
-
-    comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))
-
-    // push down invalid
-    val originalQuery1 = {
-      x.select('a, 'b)
-       .sortBy(SortOrder('a, Ascending))
-       .select('b)
-    }
-
-    val optimized1 = Optimize.execute(originalQuery1.analyze)
-    val correctAnswer1 =
-      x.select('a, 'b)
-       .sortBy(SortOrder('a, Ascending))
-       .select('b).analyze
-
-    comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1))
-  }
-
   test("push project and filter down into sample") {
     val x = testRelation.subquery('x)
     val originalQuery =

http://git-wip-us.apache.org/repos/asf/spark/blob/07f92ef1/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index 1ab53a1..2f382bb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -108,7 +108,7 @@ class JoinOptimizationSuite extends PlanTest {
       Project(Seq($"x.key", $"y.key"),
         Join(
           Project(Seq($"x.key"), SubqueryAlias("x", input)),
-          Project(Seq($"y.key"), BroadcastHint(SubqueryAlias("y", input))),
+          BroadcastHint(Project(Seq($"y.key"), SubqueryAlias("y", input))),
           Inner, None)).analyze
 
     comparePlans(optimized, expected)

http://git-wip-us.apache.org/repos/asf/spark/blob/07f92ef1/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
index 4db88a0..6bc4649 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
  * For lazy computing, be sure the generator.terminate() called in the very last
@@ -54,17 +55,19 @@ case class Generate(
     child: SparkPlan)
   extends UnaryNode {
 
+  private[sql] override lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
   override def expressions: Seq[Expression] = generator :: Nil
 
   val boundGenerator = BindReferences.bindReference(generator, child.output)
 
   protected override def doExecute(): RDD[InternalRow] = {
     // boundGenerator.terminate() should be triggered after all of the rows in the partition
-    if (join) {
+    val rows = if (join) {
       child.execute().mapPartitionsInternal { iter =>
-        val generatorNullRow = InternalRow.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null))
+        val generatorNullRow = new GenericInternalRow(generator.elementTypes.size)
         val joinedRow = new JoinedRow
-        val proj = UnsafeProjection.create(output, output)
 
         iter.flatMap { row =>
           // we should always set the left (child output)
@@ -73,19 +76,26 @@ case class Generate(
           if (outer && outputRows.isEmpty) {
             joinedRow.withRight(generatorNullRow) :: Nil
           } else {
-            outputRows.map(or => joinedRow.withRight(or))
+            outputRows.map(joinedRow.withRight)
           }
-        } ++ LazyIterator(() => boundGenerator.terminate()).map { row =>
+        } ++ LazyIterator(boundGenerator.terminate).map { row =>
           // we leave the left side as the last element of its child output
           // keep it the same as Hive does
-          proj(joinedRow.withRight(row))
+          joinedRow.withRight(row)
         }
       }
     } else {
       child.execute().mapPartitionsInternal { iter =>
-        val proj = UnsafeProjection.create(output, output)
-        (iter.flatMap(row => boundGenerator.eval(row)) ++
-          LazyIterator(() => boundGenerator.terminate())).map(proj)
+        iter.flatMap(boundGenerator.eval) ++ LazyIterator(boundGenerator.terminate)
+      }
+    }
+
+    val numOutputRows = longMetric("numOutputRows")
+    rows.mapPartitionsInternal { iter =>
+      val proj = UnsafeProjection.create(output, output)
+      iter.map { r =>
+        numOutputRows += 1
+        proj(r)
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/07f92ef1/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
index 4858140..22d4278 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
@@ -26,7 +26,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.{LeafNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -63,7 +64,7 @@ private[sql] case class InMemoryRelation(
     @transient private[sql] var _cachedColumnBuffers: RDD[CachedBatch] = null,
     @transient private[sql] var _statistics: Statistics = null,
     private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null)
-  extends LogicalPlan with MultiInstanceRelation {
+  extends logical.LeafNode with MultiInstanceRelation {
 
   override def producedAttributes: AttributeSet = outputSet
 
@@ -184,8 +185,6 @@ private[sql] case class InMemoryRelation(
       _cachedColumnBuffers, statisticsToBePropagated, batchStats)
   }
 
-  override def children: Seq[LogicalPlan] = Seq.empty
-
   override def newInstance(): this.type = {
     new InMemoryRelation(
       output.map(_.newInstance()),

http://git-wip-us.apache.org/repos/asf/spark/blob/07f92ef1/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 4930c48..b8d1b5a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -194,6 +194,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Row("a", Seq("a"), 1) :: Nil)
   }
 
+  test("sort after generate with join=true") {
+    val df = Seq((Array("a"), 1)).toDF("a", "b")
+
+    checkAnswer(
+      df.select($"*", explode($"a").as("c")).sortWithinPartitions("b", "c"),
+      Row(Seq("a"), 1, "a") :: Nil)
+  }
+
   test("selectExpr") {
     checkAnswer(
       testData.selectExpr("abs(key)", "value"),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org