You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2019/07/02 12:55:51 UTC

[spark] branch master updated: [SPARK-25353][SQL] executeTake in SparkPlan is modified to avoid unnecessary decoding.

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2ff1ac5  [SPARK-25353][SQL] executeTake in SparkPlan is modified to avoid unnecessary decoding.
2ff1ac5 is described below

commit 2ff1ac5d9f95576bb6998b989483bb060ab42228
Author: Dooyoung Hwang <do...@sk.com>
AuthorDate: Tue Jul 2 20:55:24 2019 +0800

    [SPARK-25353][SQL] executeTake in SparkPlan is modified to avoid unnecessary decoding.
    
    ## What changes were proposed in this pull request?
    In some cases, executeTake in SparkPlan could decode more than necessary.
    
    For example, with the odd/even partitioning below, the partitions return 100 rows in total even though the result is limited to 51. 'executeTake' in SparkPlan decodes all of them, so 49 rows are decoded unnecessarily.
    
    ```scala
    spark.sparkContext.parallelize((0 until 100).map(i => (i, 1))).toDF()
          .repartitionByRange(2, $"_1" % 2).limit(51).collect()
    ```
    
    By using an iterator of the Scala collection, we can ensure that at most n rows are decoded.
    
    ## How was this patch tested?
    Existing unit tests that call limit function of DataFrame.
    
    testOnly *SQLQuerySuite
    testOnly *DataFrameSuite
    
    Closes #22347 from Dooyoung-Hwang/refactor_execute_take.
    
    Authored-by: Dooyoung Hwang <do...@sk.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../org/apache/spark/sql/execution/SparkPlan.scala | 34 ++++++++++++----------
 1 file changed, 19 insertions(+), 15 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 6deb90c..2baf2e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -409,12 +409,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
       return new Array[InternalRow](0)
     }
 
-    val childRDD = getByteArrayRdd(n).map(_._2)
+    val childRDD = getByteArrayRdd(n)
 
     val buf = new ArrayBuffer[InternalRow]
     val totalParts = childRDD.partitions.length
     var partsScanned = 0
-    while (buf.size < n && partsScanned < totalParts) {
+    while (buf.length < n && partsScanned < totalParts) {
       // The number of partitions to try in this iteration. It is ok for this number to be
       // greater than totalParts because we actually cap it at totalParts in runJob.
       var numPartsToTry = 1L
@@ -426,28 +426,32 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
         if (buf.isEmpty) {
           numPartsToTry = partsScanned * limitScaleUpFactor
         } else {
-          val left = n - buf.size
+          val left = n - buf.length
           // As left > 0, numPartsToTry is always >= 1
-          numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt
+          numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.length).toInt
           numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor)
         }
       }
 
       val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
       val sc = sqlContext.sparkContext
-      val res = sc.runJob(childRDD,
-        (it: Iterator[Array[Byte]]) => if (it.hasNext) it.next() else Array.empty[Byte], p)
-
-      buf ++= res.flatMap(decodeUnsafeRows)
-
+      val res = sc.runJob(childRDD, (it: Iterator[(Long, Array[Byte])]) =>
+        if (it.hasNext) it.next() else (0L, Array.empty[Byte]), p)
+
+      var i = 0
+      while (buf.length < n && i < res.length) {
+        val rows = decodeUnsafeRows(res(i)._2)
+        val rowsToTake = if (n - buf.length >= res(i)._1) {
+          rows.toArray
+        } else {
+          rows.take(n - buf.length).toArray
+        }
+        buf ++= rowsToTake
+        i += 1
+      }
       partsScanned += p.size
     }
-
-    if (buf.size > n) {
-      buf.take(n).toArray
-    } else {
-      buf.toArray
-    }
+    buf.toArray
   }
 
   protected def newMutableProjection(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org