Posted to commits@spark.apache.org by jo...@apache.org on 2016/02/10 20:01:09 UTC

spark git commit: [SPARK-13254][SQL] Fix planning of TakeOrderedAndProject operator

Repository: spark
Updated Branches:
  refs/heads/master 80cb963ad -> 5cf20598c


[SPARK-13254][SQL] Fix planning of TakeOrderedAndProject operator

The patch for SPARK-8964 ("use Exchange to perform shuffle in Limit" / #7334) inadvertently broke the planning of the `TakeOrderedAndProject` operator: because `ReturnAnswer` became the new root of the query plan, the `TakeOrderedAndProject` rule could no longer match before `BasicOperators` claimed the plan.
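
As a self-contained illustration of the matching failure (toy case classes, not the real Catalyst types), a rule that expects `Limit` at the root silently stops firing once the plan is wrapped:

    // Toy model of the bug; all names here are illustrative stand-ins.
    sealed trait Plan
    case class Scan(table: String) extends Plan
    case class Sort(child: Plan) extends Plan
    case class Limit(n: Int, child: Plan) extends Plan
    case class ReturnAnswer(child: Plan) extends Plan

    // Pre-fix rule: only matches Limit(Sort(...)) at the root.
    def oldRule(plan: Plan): Option[String] = plan match {
      case Limit(n, Sort(_)) => Some(s"TakeOrderedAndProject($n)")
      case _ => None // ReturnAnswer(Limit(Sort(...))) falls through here
    }

    // Post-fix rule: also matches beneath the ReturnAnswer wrapper.
    def newRule(plan: Plan): Option[String] = plan match {
      case ReturnAnswer(Limit(n, Sort(_))) => Some(s"TakeOrderedAndProject($n)")
      case Limit(n, Sort(_)) => Some(s"TakeOrderedAndProject($n)")
      case _ => None
    }

    val wrapped = ReturnAnswer(Limit(10, Sort(Scan("t"))))
    assert(oldRule(wrapped).isEmpty)   // the efficient plan is never chosen
    assert(newRule(wrapped).isDefined) // a SpecialLimits-style match fires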

This patch fixes the problem by moving the `TakeOrderedAndProject` and `CollectLimit` rules into a single strategy, `SpecialLimits`, which pattern-matches on `ReturnAnswer` at the root.

In addition, I changed the `TakeOrderedAndProject` operator to make its `doExecute()` method lazy: instead of eagerly collecting all results to the driver and re-parallelizing them, it now computes a per-partition top-K, shuffles the candidates to a single partition, and takes the top-K of the survivors. I also added a new `TakeOrderedAndProjectSuite` which tests the new code path.
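
For intuition, the new `doExecute()` follows the standard two-phase distributed top-K pattern. A minimal sketch in plain RDD operations (the real operator uses `Utils.takeOrdered`, an `UnsafeRowSerializer`, and a `ShuffledRowRDD` rather than the naive sort-and-take and `repartition` below):

    import scala.reflect.ClassTag
    import org.apache.spark.rdd.RDD

    // Phase 1: keep only the k smallest rows of each input partition.
    // Phase 2: shuffle the surviving candidates (at most k per input
    // partition) to one partition and take the k smallest overall.
    def topK[T: Ordering: ClassTag](rdd: RDD[T], k: Int): RDD[T] = {
      rdd
        .mapPartitions(iter => iter.toArray.sorted.take(k).iterator)
        .repartition(1)
        .mapPartitions(iter => iter.toArray.sorted.take(k).iterator)
    }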

/cc davies and marmbrus for review.

Author: Josh Rosen <jo...@databricks.com>

Closes #11145 from JoshRosen/take-ordered-and-project-fix.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5cf20598
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5cf20598
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5cf20598

Branch: refs/heads/master
Commit: 5cf20598cec4e60b53c0e40dc4243f436396e7fc
Parents: 80cb963
Author: Josh Rosen <jo...@databricks.com>
Authored: Wed Feb 10 11:00:38 2016 -0800
Committer: Josh Rosen <jo...@databricks.com>
Committed: Wed Feb 10 11:00:38 2016 -0800

----------------------------------------------------------------------
 .../spark/sql/execution/SparkPlanner.scala      |  2 +-
 .../spark/sql/execution/SparkStrategies.scala   | 40 +++++----
 .../org/apache/spark/sql/execution/limit.scala  | 30 ++++---
 .../spark/sql/execution/PlannerSuite.scala      | 44 +++++-----
 .../execution/TakeOrderedAndProjectSuite.scala  | 85 ++++++++++++++++++++
 .../org/apache/spark/sql/hive/HiveContext.scala |  2 +-
 6 files changed, 159 insertions(+), 44 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5cf20598/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index 6e9a4df..d1569a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -31,7 +31,7 @@ class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies {
     sqlContext.experimental.extraStrategies ++ (
       DataSourceStrategy ::
       DDLStrategy ::
-      TakeOrderedAndProject ::
+      SpecialLimits ::
       Aggregation ::
       LeftSemiJoin ::
       EquiJoinSelection ::
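
Note on the list above: `SparkPlanner` tries these strategies in order, and `sqlContext.experimental.extraStrategies` is consulted before any of the built-ins. A hedged sketch of that extension point follows; the object name below is made up, but the `extraStrategies` hook itself is the real API:

    import org.apache.spark.sql.Strategy
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.execution.SparkPlan

    // A strategy that declines every plan by returning Nil, so planning
    // falls through to SpecialLimits and the other built-in strategies.
    object PassThroughStrategy extends Strategy {
      override def apply(plan: LogicalPlan): Seq[SparkPlan] = Nil
    }

    // sqlContext.experimental.extraStrategies = Seq(PassThroughStrategy)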

http://git-wip-us.apache.org/repos/asf/spark/blob/5cf20598/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index ee392e4..598ddd7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -33,6 +33,31 @@ import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
 private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   self: SparkPlanner =>
 
+  /**
+   * Plans special cases of limit operators.
+   */
+  object SpecialLimits extends Strategy {
+    override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      case logical.ReturnAnswer(rootPlan) => rootPlan match {
+        case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
+          execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil
+        case logical.Limit(
+            IntegerLiteral(limit),
+            logical.Project(projectList, logical.Sort(order, true, child))) =>
+          execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil
+        case logical.Limit(IntegerLiteral(limit), child) =>
+          execution.CollectLimit(limit, planLater(child)) :: Nil
+        case other => planLater(other) :: Nil
+      }
+      case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
+        execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil
+      case logical.Limit(
+          IntegerLiteral(limit), logical.Project(projectList, logical.Sort(order, true, child))) =>
+        execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil
+      case _ => Nil
+    }
+  }
+
   object LeftSemiJoin extends Strategy with PredicateHelper {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case ExtractEquiJoinKeys(
@@ -264,18 +289,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
   protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1)
 
-  object TakeOrderedAndProject extends Strategy {
-    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
-        execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil
-      case logical.Limit(
-             IntegerLiteral(limit),
-             logical.Project(projectList, logical.Sort(order, true, child))) =>
-        execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil
-      case _ => Nil
-    }
-  }
-
   object InMemoryScans extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case PhysicalOperation(projectList, filters, mem: InMemoryRelation) =>
@@ -338,8 +351,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil
       case logical.LocalRelation(output, data) =>
         LocalTableScan(output, data) :: Nil
-      case logical.ReturnAnswer(logical.Limit(IntegerLiteral(limit), child)) =>
-        execution.CollectLimit(limit, planLater(child)) :: Nil
       case logical.Limit(IntegerLiteral(limit), child) =>
         val perPartitionLimit = execution.LocalLimit(limit, planLater(child))
         val globalLimit = execution.GlobalLimit(limit, perPartitionLimit)
@@ -362,7 +373,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
       case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
       case BroadcastHint(child) => planLater(child) :: Nil
-      case logical.ReturnAnswer(child) => planLater(child) :: Nil
       case _ => Nil
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/5cf20598/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 256f422..04daf9d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -83,8 +83,7 @@ case class TakeOrderedAndProject(
     child: SparkPlan) extends UnaryNode {
 
   override def output: Seq[Attribute] = {
-    val projectOutput = projectList.map(_.map(_.toAttribute))
-    projectOutput.getOrElse(child.output)
+    projectList.map(_.map(_.toAttribute)).getOrElse(child.output)
   }
 
   override def outputPartitioning: Partitioning = SinglePartition
@@ -93,7 +92,7 @@ case class TakeOrderedAndProject(
   // and this ordering needs to be created on the driver in order to be passed into Spark core code.
   private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output)
 
-  private def collectData(): Array[InternalRow] = {
+  override def executeCollect(): Array[InternalRow] = {
     val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
     if (projectList.isDefined) {
       val proj = UnsafeProjection.create(projectList.get, child.output)
@@ -103,13 +102,26 @@ case class TakeOrderedAndProject(
     }
   }
 
-  override def executeCollect(): Array[InternalRow] = {
-    collectData()
-  }
+  private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
 
-  // TODO: Terminal split should be implemented differently from non-terminal split.
-  // TODO: Pick num splits based on |limit|.
-  protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1)
+  protected override def doExecute(): RDD[InternalRow] = {
+    val localTopK: RDD[InternalRow] = {
+      child.execute().map(_.copy()).mapPartitions { iter =>
+        org.apache.spark.util.collection.Utils.takeOrdered(iter, limit)(ord)
+      }
+    }
+    val shuffled = new ShuffledRowRDD(
+      Exchange.prepareShuffleDependency(localTopK, child.output, SinglePartition, serializer))
+    shuffled.mapPartitions { iter =>
+      val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord)
+      if (projectList.isDefined) {
+        val proj = UnsafeProjection.create(projectList.get, child.output)
+        topK.map(r => proj(r))
+      } else {
+        topK
+      }
+    }
+  }
 
   override def outputOrdering: Seq[SortOrder] = sortOrder
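
Note that the `executeCollect()` override above remains the fast path when the operator is the root of the executed plan: a terminal sort-plus-limit `collect()` performs a driver-side `takeOrdered` with no shuffle, while a plan that consumes the limit's output runs the new distributed `doExecute()`. A rough usage sketch, assuming a DataFrame `df` with an integer column "key":

    // Terminal sort+limit: TakeOrderedAndProject sits at the root and
    // collect() goes through the executeCollect() fast path.
    val top = df.sort("key").limit(2).collect()

    // Same operator mid-plan (the filter consumes its output), so the
    // shuffle-based doExecute() path runs instead.
    val rest = df.sort("key").limit(2).filter("key > 10").collect()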
 

http://git-wip-us.apache.org/repos/asf/spark/blob/5cf20598/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index a64ad40..250ce8f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
@@ -161,30 +162,37 @@ class PlannerSuite extends SharedSQLContext {
     }
   }
 
-  test("efficient limit -> project -> sort") {
-    {
-      val query =
-        testData.select('key, 'value).sort('key).limit(2).logicalPlan
-      val planned = sqlContext.planner.TakeOrderedAndProject(query)
-      assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
-      assert(planned.head.output === testData.select('key, 'value).logicalPlan.output)
-    }
+  test("efficient terminal limit -> sort should use TakeOrderedAndProject") {
+    val query = testData.select('key, 'value).sort('key).limit(2)
+    val planned = query.queryExecution.executedPlan
+    assert(planned.isInstanceOf[execution.TakeOrderedAndProject])
+    assert(planned.output === testData.select('key, 'value).logicalPlan.output)
+  }
 
-    {
-      // We need to make sure TakeOrderedAndProject's output is correct when we push a project
-      // into it.
-      val query =
-        testData.select('key, 'value).sort('key).select('value, 'key).limit(2).logicalPlan
-      val planned = sqlContext.planner.TakeOrderedAndProject(query)
-      assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
-      assert(planned.head.output === testData.select('value, 'key).logicalPlan.output)
-    }
+  test("terminal limit -> project -> sort should use TakeOrderedAndProject") {
+    val query = testData.select('key, 'value).sort('key).select('value, 'key).limit(2)
+    val planned = query.queryExecution.executedPlan
+    assert(planned.isInstanceOf[execution.TakeOrderedAndProject])
+    assert(planned.output === testData.select('value, 'key).logicalPlan.output)
   }
 
-  test("terminal limits use CollectLimit") {
+  test("terminal limits that are not handled by TakeOrderedAndProject should use CollectLimit") {
     val query = testData.select('value).limit(2)
     val planned = query.queryExecution.sparkPlan
     assert(planned.isInstanceOf[CollectLimit])
+    assert(planned.output === testData.select('value).logicalPlan.output)
+  }
+
+  test("TakeOrderedAndProject can appear in the middle of plans") {
+    val query = testData.select('key, 'value).sort('key).limit(2).filter('key === 3)
+    val planned = query.queryExecution.executedPlan
+    assert(planned.find(_.isInstanceOf[TakeOrderedAndProject]).isDefined)
+  }
+
+  test("CollectLimit can appear in the middle of a plan when caching is used") {
+    val query = testData.select('key, 'value).limit(2).cache()
+    val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation]
+    assert(planned.child.isInstanceOf[CollectLimit])
   }
 
   test("PartitioningCollection") {

http://git-wip-us.apache.org/repos/asf/spark/blob/5cf20598/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
new file mode 100644
index 0000000..03cb04a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.util.Random
+
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types._
+
+
+class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {
+
+  private var rand: Random = _
+  private var seed: Long = 0
+
+  protected override def beforeAll(): Unit = {
+    super.beforeAll()
+    seed = System.currentTimeMillis()
+    rand = new Random(seed)
+  }
+
+  private def generateRandomInputData(): DataFrame = {
+    val schema = new StructType()
+      .add("a", IntegerType, nullable = false)
+      .add("b", IntegerType, nullable = false)
+    val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt()))
+    sqlContext.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema)
+  }
+
+  /**
+   * Adds a no-op filter to the child plan in order to prevent executeCollect() from being
+   * called directly on the child plan.
+   */
+  private def noOpFilter(plan: SparkPlan): SparkPlan = Filter(Literal(true), plan)
+
+  val limit = 250
+  val sortOrder = 'a.desc :: 'b.desc :: Nil
+
+  test("TakeOrderedAndProject.doExecute without project") {
+    withClue(s"seed = $seed") {
+      checkThatPlansAgree(
+        generateRandomInputData(),
+        input =>
+          noOpFilter(TakeOrderedAndProject(limit, sortOrder, None, input)),
+        input =>
+          GlobalLimit(limit,
+            LocalLimit(limit,
+              Sort(sortOrder, global = true, input))),
+        sortAnswers = false)
+    }
+  }
+
+  test("TakeOrderedAndProject.doExecute with project") {
+    withClue(s"seed = $seed") {
+      checkThatPlansAgree(
+        generateRandomInputData(),
+        input =>
+          noOpFilter(TakeOrderedAndProject(limit, sortOrder, Some(Seq(input.output.last)), input)),
+        input =>
+          GlobalLimit(limit,
+            LocalLimit(limit,
+              Project(Seq(input.output.last),
+                Sort(sortOrder, global = true, input)))),
+        sortAnswers = false)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/5cf20598/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 05863ae..2433b54 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -559,7 +559,7 @@ class HiveContext private[hive](
       HiveCommandStrategy(self),
       HiveDDLStrategy,
       DDLStrategy,
-      TakeOrderedAndProject,
+      SpecialLimits,
       InMemoryScans,
       HiveTableScans,
       DataSinks,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org