You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2020/04/20 18:59:12 UTC

[spark] branch branch-3.0 updated: [SPARK-31475][SQL] Broadcast stage in AQE did not timeout

This is an automated email from the ASF dual-hosted git repository.

lixiao pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 2e32160  [SPARK-31475][SQL] Broadcast stage in AQE did not timeout
2e32160 is described below

commit 2e3216012e8ad85d4cd88671493dd6e4d0e6a668
Author: Maryann Xue <ma...@gmail.com>
AuthorDate: Mon Apr 20 11:55:48 2020 -0700

    [SPARK-31475][SQL] Broadcast stage in AQE did not timeout
    
    ### What changes were proposed in this pull request?
    
    This PR adds a timeout to the Future of a BroadcastQueryStageExec so that it has the same timeout behavior as a non-AQE broadcast exchange.
    
    ### Why are the changes needed?
    
    This is to make the broadcast timeout behavior in AQE consistent with that in non-AQE.
    
    ### Does this PR introduce any user-facing change?
    
    No.
    
    ### How was this patch tested?
    
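    Added a unit test ("Broadcast timeout" in BroadcastJoinSuite), which runs with adaptive execution both disabled (BroadcastJoinSuite) and enabled (BroadcastJoinSuiteAE).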
    
    Closes #28250 from maryannxue/aqe-broadcast-timeout.
    
    Authored-by: Maryann Xue <ma...@gmail.com>
    Signed-off-by: gatorsmile <ga...@gmail.com>
    (cherry picked from commit 44d370dd4501f0a4abb7194f7cff0d346aac0992)
    Signed-off-by: gatorsmile <ga...@gmail.com>
---
 .../execution/adaptive/AdaptiveSparkPlanExec.scala |  2 +-
 .../sql/execution/adaptive/QueryStageExec.scala    | 35 ++++++++++++++++++----
 .../execution/exchange/BroadcastExchangeExec.scala |  8 ++---
 .../sql/execution/joins/BroadcastJoinSuite.scala   | 23 ++++++++++++--
 4 files changed, 56 insertions(+), 12 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 2b46724..0ec8b5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -546,7 +546,7 @@ case class AdaptiveSparkPlanExec(
 }
 
 object AdaptiveSparkPlanExec {
-  private val executionContext = ExecutionContext.fromExecutorService(
+  private[adaptive] val executionContext = ExecutionContext.fromExecutorService(
     ThreadUtils.newDaemonCachedThreadPool("QueryStageCreator", 16))
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
index beaa972..f414f85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
@@ -17,9 +17,11 @@
 
 package org.apache.spark.sql.execution.adaptive
 
-import scala.concurrent.Future
+import java.util.concurrent.TimeUnit
 
-import org.apache.spark.{FutureAction, MapOutputStatistics}
+import scala.concurrent.{Future, Promise}
+
+import org.apache.spark.{FutureAction, MapOutputStatistics, SparkException}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -28,6 +30,8 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.exchange._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.ThreadUtils
 
 /**
  * A query stage is an independent subgraph of the query plan. Query stage materializes its output
@@ -100,8 +104,8 @@ abstract class QueryStageExec extends LeafExecNode {
   override def executeTail(n: Int): Array[InternalRow] = plan.executeTail(n)
   override def executeToIterator(): Iterator[InternalRow] = plan.executeToIterator()
 
-  override def doPrepare(): Unit = plan.prepare()
-  override def doExecute(): RDD[InternalRow] = plan.execute()
+  protected override def doPrepare(): Unit = plan.prepare()
+  protected override def doExecute(): RDD[InternalRow] = plan.execute()
   override def doExecuteBroadcast[T](): Broadcast[T] = plan.executeBroadcast()
   override def doCanonicalize(): SparkPlan = plan.canonicalized
 
@@ -187,8 +191,24 @@ case class BroadcastQueryStageExec(
       throw new IllegalStateException("wrong plan for broadcast stage:\n " + plan.treeString)
   }
 
+  @transient private lazy val materializeWithTimeout = {
+    val broadcastFuture = broadcast.completionFuture
+    val timeout = SQLConf.get.broadcastTimeout
+    val promise = Promise[Any]()
+    val fail = BroadcastQueryStageExec.scheduledExecutor.schedule(new Runnable() {
+      override def run(): Unit = {
+        promise.tryFailure(new SparkException(s"Could not execute broadcast in $timeout secs. " +
+          s"You can increase the timeout for broadcasts via ${SQLConf.BROADCAST_TIMEOUT.key} or " +
+          s"disable broadcast join by setting ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1"))
+      }
+    }, timeout, TimeUnit.SECONDS)
+    broadcastFuture.onComplete(_ => fail.cancel(false))(AdaptiveSparkPlanExec.executionContext)
+    Future.firstCompletedOf(
+      Seq(broadcastFuture, promise.future))(AdaptiveSparkPlanExec.executionContext)
+  }
+
   override def doMaterialize(): Future[Any] = {
-    broadcast.completionFuture
+    materializeWithTimeout
   }
 
   override def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec = {
@@ -204,3 +224,8 @@ case class BroadcastQueryStageExec(
     }
   }
 }
+
+object BroadcastQueryStageExec {
+  private val scheduledExecutor =
+    ThreadUtils.newDaemonSingleThreadScheduledExecutor("BroadcastStageTimeout")
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index 65e6b7c..ba5261c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -118,7 +118,7 @@ case class BroadcastExchangeExec(
               System.nanoTime() - beforeBroadcast)
             val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
             SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
-            promise.success(broadcasted)
+            promise.trySuccess(broadcasted)
             broadcasted
           } catch {
             // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
@@ -131,14 +131,14 @@ case class BroadcastExchangeExec(
                   s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark " +
                   s"driver memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value.")
                   .initCause(oe.getCause))
-              promise.failure(ex)
+              promise.tryFailure(ex)
               throw ex
             case e if !NonFatal(e) =>
               val ex = new SparkFatalException(e)
-              promise.failure(ex)
+              promise.tryFailure(ex)
               throw ex
             case e: Throwable =>
-              promise.failure(e)
+              promise.tryFailure(e)
               throw e
           }
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 5ce758e..64ecf5d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft}
 import org.apache.spark.sql.catalyst.plans.logical.BROADCAST
 import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec}
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AdaptiveTestUtils, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.functions._
@@ -39,7 +39,8 @@ import org.apache.spark.sql.types.{LongType, ShortType}
  * unsafe map in [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered
  * without serializing the hashed relation, which does not happen in local mode.
  */
-class BroadcastJoinSuite extends QueryTest with SQLTestUtils with AdaptiveSparkPlanHelper {
+abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
+  with AdaptiveSparkPlanHelper {
   import testImplicits._
 
   protected var spark: SparkSession = null
@@ -398,4 +399,22 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils with AdaptiveSparkP
         }
     }
   }
+
+  test("Broadcast timeout") {
+    val timeout = 5
+    val slowUDF = udf({ x: Int => Thread.sleep(timeout * 10 * 1000); x })
+    val df1 = spark.range(10).select($"id" as 'a)
+    val df2 = spark.range(5).select(slowUDF($"id") as 'a)
+    val testDf = df1.join(broadcast(df2), "a")
+    withSQLConf(SQLConf.BROADCAST_TIMEOUT.key -> timeout.toString) {
+      val e = intercept[Exception] {
+        testDf.collect()
+      }
+      AdaptiveTestUtils.assertExceptionMessage(e, s"Could not execute broadcast in $timeout secs.")
+    }
+  }
 }
+
+class BroadcastJoinSuite extends BroadcastJoinSuiteBase with DisableAdaptiveExecutionSuite
+
+class BroadcastJoinSuiteAE extends BroadcastJoinSuiteBase with EnableAdaptiveExecutionSuite


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org