You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/01/10 02:15:35 UTC

spark git commit: [SPARK-23005][CORE] Improve RDD.take on small number of partitions

Repository: spark
Updated Branches:
  refs/heads/master 2250cb75b -> 96ba217a0


[SPARK-23005][CORE] Improve RDD.take on small number of partitions

## What changes were proposed in this pull request?
In the current implementation of RDD.take, we overestimate the number of partitions we need to try by 50%:
`(1.5 * num * partsScanned / buf.size).toInt`
However, `.toInt` truncates the fractional part, so when the estimate is small the result is lower than intended.
E.g., 2.9 becomes 2, when it should be rounded up to 3.
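Use Math.ceil to round up instead, and interpolate on the number of rows still needed (`left = num - buf.size`) rather than subtracting `partsScanned` from the estimate.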

Also clean up the code in RDD.scala.

## How was this patch tested?

Unit test

Author: Wang Gengliang <lt...@gmail.com>

Closes #20200 from gengliangwang/Take.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/96ba217a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/96ba217a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/96ba217a

Branch: refs/heads/master
Commit: 96ba217a06fbe1dad703447d7058cb7841653861
Parents: 2250cb7
Author: Wang Gengliang <lt...@gmail.com>
Authored: Wed Jan 10 10:15:27 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Wed Jan 10 10:15:27 2018 +0800

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/rdd/RDD.scala   | 27 ++++++++++----------
 .../apache/spark/sql/execution/SparkPlan.scala  |  5 ++--
 2 files changed, 16 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/96ba217a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 8798dfc..7859781 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -150,7 +150,7 @@ abstract class RDD[T: ClassTag](
   val id: Int = sc.newRddId()
 
   /** A friendly name for this RDD */
-  @transient var name: String = null
+  @transient var name: String = _
 
   /** Assign a name to this RDD */
   def setName(_name: String): this.type = {
@@ -224,8 +224,8 @@ abstract class RDD[T: ClassTag](
 
   // Our dependencies and partitions will be gotten by calling subclass's methods below, and will
   // be overwritten when we're checkpointed
-  private var dependencies_ : Seq[Dependency[_]] = null
-  @transient private var partitions_ : Array[Partition] = null
+  private var dependencies_ : Seq[Dependency[_]] = _
+  @transient private var partitions_ : Array[Partition] = _
 
   /** An Option holding our checkpoint RDD, if we are checkpointed */
   private def checkpointRDD: Option[CheckpointRDD[T]] = checkpointData.flatMap(_.checkpointRDD)
@@ -297,7 +297,7 @@ abstract class RDD[T: ClassTag](
   private[spark] def getNarrowAncestors: Seq[RDD[_]] = {
     val ancestors = new mutable.HashSet[RDD[_]]
 
-    def visit(rdd: RDD[_]) {
+    def visit(rdd: RDD[_]): Unit = {
       val narrowDependencies = rdd.dependencies.filter(_.isInstanceOf[NarrowDependency[_]])
       val narrowParents = narrowDependencies.map(_.rdd)
       val narrowParentsNotVisited = narrowParents.filterNot(ancestors.contains)
@@ -449,7 +449,7 @@ abstract class RDD[T: ClassTag](
     if (shuffle) {
       /** Distributes elements evenly across output partitions, starting from a random partition. */
       val distributePartition = (index: Int, items: Iterator[T]) => {
-        var position = (new Random(hashing.byteswap32(index))).nextInt(numPartitions)
+        var position = new Random(hashing.byteswap32(index)).nextInt(numPartitions)
         items.map { t =>
           // Note that the hash code of the key will just be the key itself. The HashPartitioner
           // will mod it with the number of total partitions.
@@ -951,7 +951,7 @@ abstract class RDD[T: ClassTag](
     def collectPartition(p: Int): Array[T] = {
       sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p)).head
     }
-    (0 until partitions.length).iterator.flatMap(i => collectPartition(i))
+    partitions.indices.iterator.flatMap(i => collectPartition(i))
   }
 
   /**
@@ -1338,6 +1338,7 @@ abstract class RDD[T: ClassTag](
         // The number of partitions to try in this iteration. It is ok for this number to be
         // greater than totalParts because we actually cap it at totalParts in runJob.
         var numPartsToTry = 1L
+        val left = num - buf.size
         if (partsScanned > 0) {
           // If we didn't find any rows after the previous iteration, quadruple and retry.
           // Otherwise, interpolate the number of partitions we need to try, but overestimate
@@ -1345,13 +1346,12 @@ abstract class RDD[T: ClassTag](
           if (buf.isEmpty) {
             numPartsToTry = partsScanned * scaleUpFactor
           } else {
-            // the left side of max is >=1 whenever partsScanned >= 2
-            numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1)
+            // As left > 0, numPartsToTry is always >= 1
+            numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt
             numPartsToTry = Math.min(numPartsToTry, partsScanned * scaleUpFactor)
           }
         }
 
-        val left = num - buf.size
         val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
         val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p)
 
@@ -1677,8 +1677,7 @@ abstract class RDD[T: ClassTag](
   // an RDD and its parent in every batch, in which case the parent may never be checkpointed
   // and its lineage never truncated, leading to OOMs in the long run (SPARK-6847).
   private val checkpointAllMarkedAncestors =
-    Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS))
-      .map(_.toBoolean).getOrElse(false)
+    Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)).exists(_.toBoolean)
 
   /** Returns the first parent RDD */
   protected[spark] def firstParent[U: ClassTag]: RDD[U] = {
@@ -1686,7 +1685,7 @@ abstract class RDD[T: ClassTag](
   }
 
   /** Returns the jth parent RDD: e.g. rdd.parent[T](0) is equivalent to rdd.firstParent[T] */
-  protected[spark] def parent[U: ClassTag](j: Int) = {
+  protected[spark] def parent[U: ClassTag](j: Int): RDD[U] = {
     dependencies(j).rdd.asInstanceOf[RDD[U]]
   }
 
@@ -1754,7 +1753,7 @@ abstract class RDD[T: ClassTag](
    * collected. Subclasses of RDD may override this method for implementing their own cleaning
    * logic. See [[org.apache.spark.rdd.UnionRDD]] for an example.
    */
-  protected def clearDependencies() {
+  protected def clearDependencies(): Unit = {
     dependencies_ = null
   }
 
@@ -1790,7 +1789,7 @@ abstract class RDD[T: ClassTag](
           val lastDepStrings =
             debugString(lastDep.rdd, prefix, lastDep.isInstanceOf[ShuffleDependency[_, _, _]], true)
 
-          (frontDepStrings ++ lastDepStrings)
+          frontDepStrings ++ lastDepStrings
       }
     }
     // The first RDD in the dependency stack has no parents, so no need for a +-

http://git-wip-us.apache.org/repos/asf/spark/blob/96ba217a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 82300ef..398758a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -351,8 +351,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
         if (buf.isEmpty) {
           numPartsToTry = partsScanned * limitScaleUpFactor
         } else {
-          // the left side of max is >=1 whenever partsScanned >= 2
-          numPartsToTry = Math.max((1.5 * n * partsScanned / buf.size).toInt - partsScanned, 1)
+          val left = n - buf.size
+          // As left > 0, numPartsToTry is always >= 1
+          numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt
           numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor)
         }
       }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org