Posted to commits@spark.apache.org by we...@apache.org on 2021/07/08 16:23:31 UTC

[spark] branch branch-3.2 updated: [SPARK-35874][SQL] AQE Shuffle should wait for its subqueries to finish before materializing

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new b8d3da1  [SPARK-35874][SQL] AQE Shuffle should wait for its subqueries to finish before materializing
b8d3da1 is described below

commit b8d3da16b1bdde60b50e364a4ff98cb6bf8ccd7e
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Fri Jul 9 00:20:50 2021 +0800

    [SPARK-35874][SQL] AQE Shuffle should wait for its subqueries to finish before materializing
    
    ### What changes were proposed in this pull request?
    
    Currently, AQE uses a very tricky way to trigger and wait for the subqueries
    (see the sketch after this list):
    1. submitting a stage calls `QueryStageExec.materialize`
    2. `QueryStageExec.materialize` calls `executeQuery`
    3. `executeQuery` does some preparation work, which goes to `QueryStageExec.doPrepare`
    4. `QueryStageExec.doPrepare` calls `prepare` on the shuffle/broadcast node, which triggers all the subqueries in this stage
    5. `executeQuery` then calls `waitForSubqueries`, which does nothing because `QueryStageExec` itself has no subqueries
    6. then we submit the shuffle/broadcast job, without waiting for subqueries
    7. for `ShuffleExchangeExec.mapOutputStatisticsFuture`, it calls `child.execute`, which calls `executeQuery` and waits for subqueries in the query tree of `child`
    8. the only missing case is: `ShuffleExchangeExec` itself may contain subqueries (in its repartition expressions), and AQE doesn't wait for them
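
    A minimal, self-contained Scala model of the flow above. The class and
    method names here are hypothetical stand-ins for Spark's real ones; the
    point is only to show where the wait is skipped:
    
    ```scala
    import scala.concurrent.{Future, Promise}
    import scala.concurrent.ExecutionContext.Implicits.global
    
    trait Node {
      def prepare(): Unit = ()            // step 4: triggers subqueries
      def waitForSubqueries(): Unit = ()  // step 5: no-op on the stage node
      final def executeQuery[T](body: => T): T = {
        prepare()
        waitForSubqueries()
        body
      }
    }
    
    class SubqueryExchange extends Node {
      private val subquery = Promise[Int]()
      override def prepare(): Unit = {
        // fire the repartition subquery asynchronously
        Future { Thread.sleep(100); subquery.success(42) }
        ()
      }
      // The exchange's own waitForSubqueries is never invoked by the stage,
      // so this future can be consumed before the subquery completes: the bug.
      def mapOutputStatisticsFuture: Future[Int] = subquery.future
    }
    
    class StageNode(exchange: SubqueryExchange) extends Node {
      override def prepare(): Unit = exchange.prepare()
      def materialize(): Future[Int] = executeQuery {  // steps 1-6
        exchange.mapOutputStatisticsFuture             // submitted without waiting
      }
    }
    ```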
    
    A simple fix would be to override `waitForSubqueries` in `QueryStageExec` and forward the request to the shuffle/broadcast node, but this PR proposes a different and probably cleaner way: we follow the `execute`/`doExecute` pattern in `SparkPlan` and add a similar pair of APIs for the AQE version of "execute", which gets a future from the shuffle/broadcast (sketched below).
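
    A condensed sketch of the new API shape from the diff below, with `Long`
    standing in for `MapOutputStatistics` so the snippet is self-contained:
    
    ```scala
    import scala.concurrent.Future
    
    trait ShuffleExchangeLikeSketch {
      // Prepares the plan and waits for its subqueries, like SparkPlan.executeQuery.
      protected def executeQuery[T](body: => T): T
      // Public, final entry point: subqueries finish before the job is submitted.
      final def submitShuffleJob: Future[Long] = executeQuery {
        mapOutputStatisticsFuture
      }
      // Per-implementation hook, now protected so callers go through submitShuffleJob.
      protected def mapOutputStatisticsFuture: Future[Long]
    }
    ```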
    
    ### Why are the changes needed?
    
    Bug fix: without it, AQE can submit the shuffle job before the subqueries in the shuffle's repartition expressions finish.
    
    ### Does this PR introduce _any_ user-facing change?
    
    A query that fails without the fix can now run (see the example below).
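
    For example (assuming a `SparkSession` named `spark` and the suite's
    standard `testData`/`testData2` tables), the query shape exercised by the
    new test is a repartition whose expression contains a scalar subquery:
    
    ```scala
    // Fails before this fix under AQE; runs after it.
    spark.sql(
      "SELECT b FROM testData2 DISTRIBUTE BY (b, (SELECT max(key) FROM testData))")
    ```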
    
    ### How was this patch tested?
    
    A new test in `AdaptiveQueryExecSuite` (see the diff below).
    
    Closes #33058 from cloud-fan/aqe.
    
    Authored-by: Wenchen Fan <we...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 2df67a1a1badac643308807a46ea96577e3845bd)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../sql/execution/adaptive/QueryStageExec.scala      | 20 +++++++++-----------
 .../execution/exchange/BroadcastExchangeExec.scala   | 11 ++++++++---
 .../sql/execution/exchange/ShuffleExchangeExec.scala | 14 ++++++++++----
 .../spark/sql/SparkSessionExtensionSuite.scala       |  4 ++--
 .../execution/adaptive/AdaptiveQueryExecSuite.scala  |  7 +++++++
 5 files changed, 36 insertions(+), 20 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
index c2a9b46..f308829 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
@@ -80,7 +80,7 @@ abstract class QueryStageExec extends LeafExecNode {
    * broadcasting data, etc. The caller side can use the returned [[Future]] to wait until this
    * stage is ready.
    */
-  final def materialize(): Future[Any] = executeQuery {
+  final def materialize(): Future[Any] = {
     logDebug(s"Materialize query stage ${this.getClass.getSimpleName}: $id")
     doMaterialize()
   }
@@ -119,7 +119,6 @@ abstract class QueryStageExec extends LeafExecNode {
   override def executeTail(n: Int): Array[InternalRow] = plan.executeTail(n)
   override def executeToIterator(): Iterator[InternalRow] = plan.executeToIterator()
 
-  protected override def doPrepare(): Unit = plan.prepare()
   protected override def doExecute(): RDD[InternalRow] = plan.execute()
   override def supportsColumnar: Boolean = plan.supportsColumnar
   protected override def doExecuteColumnar(): RDD[ColumnarBatch] = plan.executeColumnar()
@@ -171,7 +170,9 @@ case class ShuffleQueryStageExec(
       throw new IllegalStateException(s"wrong plan for shuffle stage:\n ${plan.treeString}")
   }
 
-  override def doMaterialize(): Future[Any] = shuffle.mapOutputStatisticsFuture
+  @transient private lazy val shuffleFuture = shuffle.submitShuffleJob
+
+  override def doMaterialize(): Future[Any] = shuffleFuture
 
   override def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec = {
     val reuse = ShuffleQueryStageExec(
@@ -182,13 +183,10 @@ case class ShuffleQueryStageExec(
     reuse
   }
 
-  override def cancel(): Unit = {
-    shuffle.mapOutputStatisticsFuture match {
-      case action: FutureAction[MapOutputStatistics]
-        if !shuffle.mapOutputStatisticsFuture.isCompleted =>
-        action.cancel()
-      case _ =>
-    }
+  override def cancel(): Unit = shuffleFuture match {
+    case action: FutureAction[MapOutputStatistics] if !action.isCompleted =>
+      action.cancel()
+    case _ =>
   }
 
   /**
@@ -224,7 +222,7 @@ case class BroadcastQueryStageExec(
   }
 
   @transient private lazy val materializeWithTimeout = {
-    val broadcastFuture = broadcast.completionFuture
+    val broadcastFuture = broadcast.submitBroadcastJob
     val timeout = conf.broadcastTimeout
     val promise = Promise[Any]()
     val fail = BroadcastQueryStageExec.scheduledExecutor.schedule(new Runnable() {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index cc12be6..7859785 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -55,10 +55,15 @@ trait BroadcastExchangeLike extends Exchange {
   def relationFuture: Future[broadcast.Broadcast[Any]]
 
   /**
-   * For registering callbacks on `relationFuture`.
-   * Note that calling this method may not start the execution of broadcast job.
+   * The asynchronous job that materializes the broadcast. It's used for registering callbacks on
+   * `relationFuture`. Note that calling this method may not start the execution of broadcast job.
+   * It also does the preparations work, such as waiting for the subqueries.
    */
-  def completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]]
+  final def submitBroadcastJob: scala.concurrent.Future[broadcast.Broadcast[Any]] = executeQuery {
+    completionFuture
+  }
+
+  protected def completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]]
 
   /**
    * Returns the runtime statistics after broadcast materialization.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 1632a9e..e8cf768 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -61,9 +61,14 @@ trait ShuffleExchangeLike extends Exchange {
   def shuffleOrigin: ShuffleOrigin
 
   /**
-   * The asynchronous job that materializes the shuffle.
+   * The asynchronous job that materializes the shuffle. It also does the preparations work,
+   * such as waiting for the subqueries.
    */
-  def mapOutputStatisticsFuture: Future[MapOutputStatistics]
+  final def submitShuffleJob: Future[MapOutputStatistics] = executeQuery {
+    mapOutputStatisticsFuture
+  }
+
+  protected def mapOutputStatisticsFuture: Future[MapOutputStatistics]
 
   /**
    * Returns the shuffle RDD with specified partition specs.
@@ -123,13 +128,14 @@ case class ShuffleExchangeExec(
 
   override def nodeName: String = "Exchange"
 
-  private val serializer: Serializer =
+  private lazy val serializer: Serializer =
     new UnsafeRowSerializer(child.output.size, longMetric("dataSize"))
 
   @transient lazy val inputRDD: RDD[InternalRow] = child.execute()
 
   // 'mapOutputStatisticsFuture' is only needed when enable AQE.
-  @transient override lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = {
+  @transient
+  override lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = {
     if (inputRDD.getNumPartitions == 0) {
       Future.successful(null)
     } else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 8225204..b1c3fd5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -829,7 +829,7 @@ case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleE
     delegate.shuffleOrigin
   }
   override def mapOutputStatisticsFuture: Future[MapOutputStatistics] =
-    delegate.mapOutputStatisticsFuture
+    delegate.submitShuffleJob
   override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] =
     delegate.getShuffleRDD(partitionSpecs)
   override def runtimeStatistics: Statistics = delegate.runtimeStatistics
@@ -848,7 +848,7 @@ case class MyBroadcastExchangeExec(delegate: BroadcastExchangeExec) extends Broa
   override def runId: UUID = delegate.runId
   override def relationFuture: java.util.concurrent.Future[Broadcast[Any]] =
     delegate.relationFuture
-  override def completionFuture: Future[Broadcast[Any]] = delegate.completionFuture
+  override def completionFuture: Future[Broadcast[Any]] = delegate.submitBroadcastJob
   override def runtimeStatistics: Statistics = delegate.runtimeStatistics
   override def child: SparkPlan = delegate.child
   override protected def doPrepare(): Unit = delegate.prepare()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 8bc67fc..d811ba7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -1952,6 +1952,13 @@ class AdaptiveQueryExecSuite
       }
     }
   }
+
+  test("SPARK-35874: AQE Shuffle should wait for its subqueries to finish before materializing") {
+    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+      val query = "SELECT b FROM testData2 DISTRIBUTE BY (b, (SELECT max(key) FROM testData))"
+      runAdaptiveAndVerifyResult(query)
+    }
+  }
 }
 
 /**
