You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by pw...@apache.org on 2014/04/02 21:48:12 UTC

git commit: StopAfter / TopK related changes

Repository: spark
Updated Branches:
  refs/heads/master 1faa57971 -> ed730c950


StopAfter / TopK related changes

1. Renamed StopAfter to Limit to be more consistent with naming in other relational databases.
2. Renamed TopK to TakeOrdered to be more consistent with Spark RDD API.
3. Avoid breaking lineage in Limit.
4. Added a bunch of `override` modifiers to execution/basicOperators.scala.

@marmbrus @liancheng

Author: Reynold Xin <rx...@apache.org>
Author: Michael Armbrust <mi...@databricks.com>

Closes #233 from rxin/limit and squashes the following commits:

13eb12a [Reynold Xin] Merge pull request #1 from marmbrus/limit
92b9727 [Michael Armbrust] More hacks to make Maps serialize with Kryo.
4fc8b4e [Reynold Xin] Merge branch 'master' of github.com:apache/spark into limit
87b7d37 [Reynold Xin] Use the proper serializer in limit.
9b79246 [Reynold Xin] Updated doc for Limit.
47d3327 [Reynold Xin] Copy tuples in Limit before shuffle.
231af3a [Reynold Xin] Limit/TakeOrdered: 1. Renamed StopAfter to Limit to be more consistent with naming in other relational databases. 2. Renamed TopK to TakeOrdered to be more consistent with Spark RDD API. 3. Avoid breaking lineage in Limit. 4. Added a bunch of `override` modifiers to execution/basicOperators.scala.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ed730c95
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ed730c95
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ed730c95

Branch: refs/heads/master
Commit: ed730c95026d322f4b24d3d9fe92050ffa74cf4a
Parents: 1faa579
Author: Reynold Xin <rx...@apache.org>
Authored: Wed Apr 2 12:48:04 2014 -0700
Committer: Patrick Wendell <pw...@gmail.com>
Committed: Wed Apr 2 12:48:04 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/SqlParser.scala   |  2 +-
 .../catalyst/plans/logical/basicOperators.scala |  2 +-
 .../scala/org/apache/spark/sql/SQLContext.scala |  2 +-
 .../sql/execution/SparkSqlSerializer.scala      |  6 ++
 .../spark/sql/execution/SparkStrategies.scala   | 10 +--
 .../spark/sql/execution/basicOperators.scala    | 71 +++++++++++++-------
 .../org/apache/spark/sql/hive/HiveContext.scala |  2 +-
 .../org/apache/spark/sql/hive/HiveQl.scala      |  4 +-
 8 files changed, 64 insertions(+), 35 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ed730c95/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 0c851c2..8de8759 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -181,7 +181,7 @@ class SqlParser extends StandardTokenParsers {
         val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection)
         val withHaving = h.map(h => Filter(h, withDistinct)).getOrElse(withDistinct)
         val withOrder = o.map(o => Sort(o, withHaving)).getOrElse(withHaving)
-        val withLimit = l.map { l => StopAfter(l, withOrder) }.getOrElse(withOrder)
+        val withLimit = l.map { l => Limit(l, withOrder) }.getOrElse(withOrder)
         withLimit
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ed730c95/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 9d16189..b39c2b3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -130,7 +130,7 @@ case class Aggregate(
   def references = child.references
 }
 
-case class StopAfter(limit: Expression, child: LogicalPlan) extends UnaryNode {
+case class Limit(limit: Expression, child: LogicalPlan) extends UnaryNode {
   def output = child.output
   def references = limit.references
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed730c95/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 69bbbdc..f4bf00f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -145,7 +145,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
     val sparkContext = self.sparkContext
 
     val strategies: Seq[Strategy] =
-      TopK ::
+      TakeOrdered ::
       PartialAggregation ::
       HashJoin ::
       ParquetOperations ::

http://git-wip-us.apache.org/repos/asf/spark/blob/ed730c95/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
index 915f551..d8e1b97 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
@@ -32,7 +32,13 @@ class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
     kryo.setRegistrationRequired(false)
     kryo.register(classOf[MutablePair[_, _]])
     kryo.register(classOf[Array[Any]])
+    // This is kinda hacky...
     kryo.register(classOf[scala.collection.immutable.Map$Map1], new MapSerializer)
+    kryo.register(classOf[scala.collection.immutable.Map$Map2], new MapSerializer)
+    kryo.register(classOf[scala.collection.immutable.Map$Map3], new MapSerializer)
+    kryo.register(classOf[scala.collection.immutable.Map$Map4], new MapSerializer)
+    kryo.register(classOf[scala.collection.immutable.Map[_,_]], new MapSerializer)
+    kryo.register(classOf[scala.collection.Map[_,_]], new MapSerializer)
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
     kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]])

http://git-wip-us.apache.org/repos/asf/spark/blob/ed730c95/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index e35ac0b..b3e51fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -158,10 +158,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     case other => other
   }
 
-  object TopK extends Strategy {
+  object TakeOrdered extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case logical.StopAfter(IntegerLiteral(limit), logical.Sort(order, child)) =>
-        execution.TopK(limit, order, planLater(child))(sparkContext) :: Nil
+      case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) =>
+        execution.TakeOrdered(limit, order, planLater(child))(sparkContext) :: Nil
       case _ => Nil
     }
   }
@@ -213,8 +213,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           sparkContext.parallelize(data.map(r =>
             new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row))
         execution.ExistingRdd(output, dataAsRdd) :: Nil
-      case logical.StopAfter(IntegerLiteral(limit), child) =>
-        execution.StopAfter(limit, planLater(child))(sparkContext) :: Nil
+      case logical.Limit(IntegerLiteral(limit), child) =>
+        execution.Limit(limit, planLater(child))(sparkContext) :: Nil
       case Unions(unionChildren) =>
         execution.Union(unionChildren.map(planLater))(sparkContext) :: Nil
       case logical.Generate(generator, join, outer, _, child) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/ed730c95/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 65cb8f8..524e502 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -19,27 +19,28 @@ package org.apache.spark.sql.execution
 
 import scala.reflect.runtime.universe.TypeTag
 
-import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext
-
+import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
+import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.{OrderedDistribution, UnspecifiedDistribution}
-import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.util.MutablePair
+
 
 case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
-  def output = projectList.map(_.toAttribute)
+  override def output = projectList.map(_.toAttribute)
 
-  def execute() = child.execute().mapPartitions { iter =>
+  override def execute() = child.execute().mapPartitions { iter =>
     @transient val reusableProjection = new MutableProjection(projectList)
     iter.map(reusableProjection)
   }
 }
 
 case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
-  def output = child.output
+  override def output = child.output
 
-  def execute() = child.execute().mapPartitions { iter =>
+  override def execute() = child.execute().mapPartitions { iter =>
     iter.filter(condition.apply(_).asInstanceOf[Boolean])
   }
 }
@@ -47,37 +48,59 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
 case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: SparkPlan)
     extends UnaryNode {
 
-  def output = child.output
+  override def output = child.output
 
   // TODO: How to pick seed?
-  def execute() = child.execute().sample(withReplacement, fraction, seed)
+  override def execute() = child.execute().sample(withReplacement, fraction, seed)
 }
 
 case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends SparkPlan {
   // TODO: attributes output by union should be distinct for nullability purposes
-  def output = children.head.output
-  def execute() = sc.union(children.map(_.execute()))
+  override def output = children.head.output
+  override def execute() = sc.union(children.map(_.execute()))
 
   override def otherCopyArgs = sc :: Nil
 }
 
-case class StopAfter(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
+/**
+ * Take the first limit elements. Note that the implementation is different depending on whether
+ * this is a terminal operator or not. If it is terminal and is invoked using executeCollect,
+ * this operator uses Spark's take method on the Spark driver. If it is not terminal or is
+ * invoked using execute, we first take the limit on each partition, and then repartition all the
+ * data to a single partition to compute the global limit.
+ */
+case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
+  // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
+  // partition local limit -> exchange into one partition -> partition local limit again
+
   override def otherCopyArgs = sc :: Nil
 
-  def output = child.output
+  override def output = child.output
 
   override def executeCollect() = child.execute().map(_.copy()).take(limit)
 
-  // TODO: Terminal split should be implemented differently from non-terminal split.
-  // TODO: Pick num splits based on |limit|.
-  def execute() = sc.makeRDD(executeCollect(), 1)
+  override def execute() = {
+    val rdd = child.execute().mapPartitions { iter =>
+      val mutablePair = new MutablePair[Boolean, Row]()
+      iter.take(limit).map(row => mutablePair.update(false, row))
+    }
+    val part = new HashPartitioner(1)
+    val shuffled = new ShuffledRDD[Boolean, Row, MutablePair[Boolean, Row]](rdd, part)
+    shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+    shuffled.mapPartitions(_.take(limit).map(_._2))
+  }
 }
 
-case class TopK(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
-               (@transient sc: SparkContext) extends UnaryNode {
+/**
+ * Take the first limit elements as defined by the sortOrder. This is logically equivalent to
+ * having a [[Limit]] operator after a [[Sort]] operator. This could have been named TopK, but
+ * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion.
+ */
+case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
+                      (@transient sc: SparkContext) extends UnaryNode {
   override def otherCopyArgs = sc :: Nil
 
-  def output = child.output
+  override def output = child.output
 
   @transient
   lazy val ordering = new RowOrdering(sortOrder)
@@ -86,7 +109,7 @@ case class TopK(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
 
   // TODO: Terminal split should be implemented differently from non-terminal split.
   // TODO: Pick num splits based on |limit|.
-  def execute() = sc.makeRDD(executeCollect(), 1)
+  override def execute() = sc.makeRDD(executeCollect(), 1)
 }
 
 
@@ -101,7 +124,7 @@ case class Sort(
   @transient
   lazy val ordering = new RowOrdering(sortOrder)
 
-  def execute() = attachTree(this, "sort") {
+  override def execute() = attachTree(this, "sort") {
     // TODO: Optimize sorting operation?
     child.execute()
       .mapPartitions(
@@ -109,7 +132,7 @@ case class Sort(
         preservesPartitioning = true)
   }
 
-  def output = child.output
+  override def output = child.output
 }
 
 object ExistingRdd {
@@ -130,6 +153,6 @@ object ExistingRdd {
 }
 
 case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
-  def execute() = rdd
+  override def execute() = rdd
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ed730c95/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 197b557..46febbf 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -188,7 +188,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
     val hiveContext = self
 
     override val strategies: Seq[Strategy] = Seq(
-      TopK,
+      TakeOrdered,
       ParquetOperations,
       HiveTableScans,
       DataSinks,

http://git-wip-us.apache.org/repos/asf/spark/blob/ed730c95/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 490a592..b2b03bc 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -529,7 +529,7 @@ object HiveQl {
 
         val withLimit =
           limitClause.map(l => nodeToExpr(l.getChildren.head))
-            .map(StopAfter(_, withSort))
+            .map(Limit(_, withSort))
             .getOrElse(withSort)
 
         // TOK_INSERT_INTO means to add files to the table.
@@ -602,7 +602,7 @@ object HiveQl {
         case Token("TOK_TABLESPLITSAMPLE",
                Token("TOK_ROWCOUNT", Nil) ::
                Token(count, Nil) :: Nil) =>
-          StopAfter(Literal(count.toInt), relation)
+          Limit(Literal(count.toInt), relation)
         case Token("TOK_TABLESPLITSAMPLE",
                Token("TOK_PERCENT", Nil) ::
                Token(fraction, Nil) :: Nil) =>