You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jo...@apache.org on 2022/08/27 00:18:41 UTC

[spark] branch master updated: [SPARK-40211][CORE][SQL] Allow customize initial partitions number in take() behavior

This is an automated email from the ASF dual-hosted git repository.

joshrosen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1178bcecc83 [SPARK-40211][CORE][SQL] Allow customize initial partitions number in take() behavior
1178bcecc83 is described below

commit 1178bcecc83925674cad4364537a39eba03e423e
Author: Ziqi Liu <zi...@databricks.com>
AuthorDate: Fri Aug 26 17:18:09 2022 -0700

    [SPARK-40211][CORE][SQL] Allow customize initial partitions number in take() behavior
    
    ### What changes were proposed in this pull request?
    [SPARK-40211](https://issues.apache.org/jira/browse/SPARK-40211) adds an `initialNumPartitions` config parameter to allow customizing the initial number of partitions to try in `take()`.
    
    ### Why are the changes needed?
    Currently, the initial number of partitions to try is hardcoded to `1`, which might cause unnecessary overhead. By setting this new configuration to a high value, we can effectively mitigate the “run multiple jobs” overhead in `take()`. We could also set it to higher-than-1-but-still-small values (like, say, 10) to achieve a middle-ground trade-off.
    
    ### Does this PR introduce _any_ user-facing change?
    NO
    
    ### How was this patch tested?
    Unit test
    
    Closes #37661 from liuzqt/SPARK-40211.
    
    Authored-by: Ziqi Liu <zi...@databricks.com>
    Signed-off-by: Josh Rosen <jo...@databricks.com>
---
 .../org/apache/spark/internal/config/package.scala |  7 ++++
 .../org/apache/spark/rdd/AsyncRDDActions.scala     | 17 +++++----
 core/src/main/scala/org/apache/spark/rdd/RDD.scala |  8 ++---
 .../test/scala/org/apache/spark/rdd/RDDSuite.scala | 39 +++++++++++++++++++-
 .../org/apache/spark/sql/internal/SQLConf.scala    | 12 +++++++
 .../org/apache/spark/sql/execution/SparkPlan.scala | 12 +++----
 .../org/apache/spark/sql/ConfigBehaviorSuite.scala | 41 ++++++++++++++++++++++
 7 files changed, 118 insertions(+), 18 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 9d1a56843ca..07d3d3e0778 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1956,6 +1956,13 @@ package object config {
       .intConf
       .createWithDefault(10)
 
+  private[spark] val RDD_LIMIT_INITIAL_NUM_PARTITIONS =
+    ConfigBuilder("spark.rdd.limit.initialNumPartitions")
+      .version("3.4.0")
+      .intConf
+      .checkValue(_ > 0, "value should be positive")
+      .createWithDefault(1)
+
   private[spark] val RDD_LIMIT_SCALE_UP_FACTOR =
     ConfigBuilder("spark.rdd.limit.scaleUpFactor")
       .version("2.1.0")
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index d6379156ccf..9f89c82db31 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -25,6 +25,7 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.{ComplexFutureAction, FutureAction, JobSubmitter}
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.{RDD_LIMIT_INITIAL_NUM_PARTITIONS, RDD_LIMIT_SCALE_UP_FACTOR}
 import org.apache.spark.util.ThreadUtils
 
 /**
@@ -72,6 +73,8 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
     val results = new ArrayBuffer[T]
     val totalParts = self.partitions.length
 
+    val scaleUpFactor = Math.max(self.conf.get(RDD_LIMIT_SCALE_UP_FACTOR), 2)
+
     /*
       Recursively triggers jobs to scan partitions until either the requested
       number of elements are retrieved, or the partitions to scan are exhausted.
@@ -84,18 +87,18 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
       } else {
         // The number of partitions to try in this iteration. It is ok for this number to be
         // greater than totalParts because we actually cap it at totalParts in runJob.
-        var numPartsToTry = 1L
+        var numPartsToTry = self.conf.get(RDD_LIMIT_INITIAL_NUM_PARTITIONS)
         if (partsScanned > 0) {
-          // If we didn't find any rows after the previous iteration, quadruple and retry.
-          // Otherwise, interpolate the number of partitions we need to try, but overestimate it
-          // by 50%. We also cap the estimation in the end.
-          if (results.size == 0) {
-            numPartsToTry = partsScanned * 4L
+          // If we didn't find any rows after the previous iteration, multiply by
+          // scaleUpFactor and retry. Otherwise, interpolate the number of partitions we need
+          // to try, but overestimate it by 50%. We also cap the estimation in the end.
+          if (results.isEmpty) {
+            numPartsToTry = partsScanned * scaleUpFactor
           } else {
             // the left side of max is >=1 whenever partsScanned >= 2
             numPartsToTry = Math.max(1,
               (1.5 * num * partsScanned / results.size).toInt - partsScanned)
-            numPartsToTry = Math.min(numPartsToTry, partsScanned * 4L)
+            numPartsToTry = Math.min(numPartsToTry, partsScanned * scaleUpFactor)
           }
         }
 
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index b7284d25122..ab175595c19 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1445,12 +1445,12 @@ abstract class RDD[T: ClassTag](
       while (buf.size < num && partsScanned < totalParts) {
         // The number of partitions to try in this iteration. It is ok for this number to be
         // greater than totalParts because we actually cap it at totalParts in runJob.
-        var numPartsToTry = 1L
+        var numPartsToTry = conf.get(RDD_LIMIT_INITIAL_NUM_PARTITIONS)
         val left = num - buf.size
         if (partsScanned > 0) {
-          // If we didn't find any rows after the previous iteration, quadruple and retry.
-          // Otherwise, interpolate the number of partitions we need to try, but overestimate
-          // it by 50%. We also cap the estimation in the end.
+          // If we didn't find any rows after the previous iteration, multiply by
+          // scaleUpFactor and retry. Otherwise, interpolate the number of partitions we need
+          // to try, but overestimate it by 50%. We also cap the estimation in the end.
           if (buf.isEmpty) {
             numPartsToTry = partsScanned * scaleUpFactor
           } else {
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index ccef00c8e9d..c64573f7a0a 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.rdd
 
 import java.io.{File, IOException, ObjectInputStream, ObjectOutputStream}
 import java.lang.management.ManagementFactory
+import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, HashMap}
@@ -32,8 +33,9 @@ import org.scalatest.concurrent.Eventually
 
 import org.apache.spark._
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
-import org.apache.spark.internal.config.RDD_PARALLEL_LISTING_THRESHOLD
+import org.apache.spark.internal.config.{RDD_LIMIT_INITIAL_NUM_PARTITIONS, RDD_PARALLEL_LISTING_THRESHOLD}
 import org.apache.spark.rdd.RDDSuiteUtils._
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.util.{ThreadUtils, Utils}
 
 class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually {
@@ -1255,6 +1257,41 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually {
     assert(numPartsPerLocation(locations(1)) > 0.4 * numCoalescedPartitions)
   }
 
+  test("SPARK-40211: customize initialNumPartitions for take") {
+    val totalElements = 100
+    val numToTake = 50
+    val rdd = sc.parallelize(0 to totalElements, totalElements)
+    import scala.language.reflectiveCalls
+    val jobCountListener = new SparkListener {
+      private var count: AtomicInteger = new AtomicInteger(0)
+      def getCount: Int = count.get
+      def reset(): Unit = count.set(0)
+      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+        count.incrementAndGet()
+      }
+    }
+    sc.addSparkListener(jobCountListener)
+    // with default RDD_LIMIT_INITIAL_NUM_PARTITIONS = 1, expecting multiple jobs
+    rdd.take(numToTake)
+    sc.listenerBus.waitUntilEmpty()
+    assert(jobCountListener.getCount > 1)
+    jobCountListener.reset()
+    rdd.takeAsync(numToTake).get()
+    sc.listenerBus.waitUntilEmpty()
+    assert(jobCountListener.getCount > 1)
+
+    // setting RDD_LIMIT_INITIAL_NUM_PARTITIONS to large number(1000), expecting only 1 job
+    sc.conf.set(RDD_LIMIT_INITIAL_NUM_PARTITIONS, 1000)
+    jobCountListener.reset()
+    rdd.take(numToTake)
+    sc.listenerBus.waitUntilEmpty()
+    assert(jobCountListener.getCount == 1)
+    jobCountListener.reset()
+    rdd.takeAsync(numToTake).get()
+    sc.listenerBus.waitUntilEmpty()
+    assert(jobCountListener.getCount == 1)
+  }
+
   // NOTE
   // Below tests calling sc.stop() have to be the last tests in this suite. If there are tests
   // running after them and if they access sc those tests will fail as sc is already closed, because
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index eb7a6a9105e..de25c19a26e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -526,6 +526,16 @@ object SQLConf {
     .checkValue(_ >= 1, "The shuffle hash join factor cannot be negative.")
     .createWithDefault(3)
 
+  val LIMIT_INITIAL_NUM_PARTITIONS = buildConf("spark.sql.limit.initialNumPartitions")
+    .internal()
+    .doc("Initial number of partitions to try when executing a take on a query. Higher values " +
+      "lead to more partitions read. Lower values might lead to longer execution times as more " +
+      "jobs will be run")
+    .version("3.4.0")
+    .intConf
+    .checkValue(_ > 0, "value should be positive")
+    .createWithDefault(1)
+
   val LIMIT_SCALE_UP_FACTOR = buildConf("spark.sql.limit.scaleUpFactor")
     .internal()
     .doc("Minimal increase rate in number of partitions between attempts when executing a take " +
@@ -4316,6 +4326,8 @@ class SQLConf extends Serializable with Logging {
 
   def autoBroadcastJoinThreshold: Long = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
 
+  def limitInitialNumPartitions: Int = getConf(LIMIT_INITIAL_NUM_PARTITIONS)
+
   def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR)
 
   def advancedPartitionPredicatePushdownEnabled: Boolean =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 50a309d443a..a56732fdc12 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -469,7 +469,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     if (n == 0) {
       return new Array[InternalRow](0)
     }
-
+    val limitScaleUpFactor = Math.max(conf.limitScaleUpFactor, 2)
+    // TODO: refactor and reuse the code from RDD's take()
     val childRDD = getByteArrayRdd(n, takeFromEnd)
 
     val buf = if (takeFromEnd) new ListBuffer[InternalRow] else new ArrayBuffer[InternalRow]
@@ -478,12 +479,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     while (buf.length < n && partsScanned < totalParts) {
       // The number of partitions to try in this iteration. It is ok for this number to be
       // greater than totalParts because we actually cap it at totalParts in runJob.
-      var numPartsToTry = 1L
+      var numPartsToTry = conf.limitInitialNumPartitions
       if (partsScanned > 0) {
-        // If we didn't find any rows after the previous iteration, quadruple and retry.
-        // Otherwise, interpolate the number of partitions we need to try, but overestimate
-        // it by 50%. We also cap the estimation in the end.
-        val limitScaleUpFactor = Math.max(conf.limitScaleUpFactor, 2)
+        // If we didn't find any rows after the previous iteration, multiply by
+        // limitScaleUpFactor and retry. Otherwise, interpolate the number of partitions we need
+        // to try, but overestimate it by 50%. We also cap the estimation in the end.
         if (buf.isEmpty) {
           numPartsToTry = partsScanned * limitScaleUpFactor
         } else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala
index 36989efbe87..9c442456ce8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala
@@ -17,8 +17,11 @@
 
 package org.apache.spark.sql
 
+import java.util.concurrent.atomic.AtomicInteger
+
 import org.apache.commons.math3.stat.inference.ChiSquareTest
 
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -68,4 +71,42 @@ class ConfigBehaviorSuite extends QueryTest with SharedSparkSession {
     }
   }
 
+  test("SPARK-40211: customize initialNumPartitions for take") {
+    val totalElements = 100
+    val numToTake = 50
+    import scala.language.reflectiveCalls
+    val jobCountListener = new SparkListener {
+      private var count: AtomicInteger = new AtomicInteger(0)
+      def getCount: Int = count.get
+      def reset(): Unit = count.set(0)
+      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+        count.incrementAndGet()
+      }
+    }
+    spark.sparkContext.addSparkListener(jobCountListener)
+    val df = spark.range(0, totalElements, 1, totalElements)
+
+    // with default LIMIT_INITIAL_NUM_PARTITIONS = 1, expecting multiple jobs
+    df.take(numToTake)
+    spark.sparkContext.listenerBus.waitUntilEmpty()
+    assert(jobCountListener.getCount > 1)
+    jobCountListener.reset()
+    df.tail(numToTake)
+    spark.sparkContext.listenerBus.waitUntilEmpty()
+    assert(jobCountListener.getCount > 1)
+
+    // setting LIMIT_INITIAL_NUM_PARTITIONS to large number(1000), expecting only 1 job
+
+    withSQLConf(SQLConf.LIMIT_INITIAL_NUM_PARTITIONS.key -> "1000") {
+      jobCountListener.reset()
+      df.take(numToTake)
+      spark.sparkContext.listenerBus.waitUntilEmpty()
+      assert(jobCountListener.getCount == 1)
+      jobCountListener.reset()
+      df.tail(numToTake)
+      spark.sparkContext.listenerBus.waitUntilEmpty()
+      assert(jobCountListener.getCount == 1)
+    }
+  }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org