You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by tg...@apache.org on 2020/07/31 16:15:41 UTC

[spark] branch branch-3.0 updated: [SPARK-32332][SQL][3.0] Support columnar exchanges

This is an automated email from the ASF dual-hosted git repository.

tgraves pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 7c91b15  [SPARK-32332][SQL][3.0] Support columnar exchanges
7c91b15 is described below

commit 7c91b15c22fe875a44e08781caf3422bfca81b19
Author: Andy Grove <an...@nvidia.com>
AuthorDate: Fri Jul 31 11:14:33 2020 -0500

    [SPARK-32332][SQL][3.0] Support columnar exchanges
    
    ### What changes were proposed in this pull request?
    Backports SPARK-32332 to 3.0 branch.
    
    ### Why are the changes needed?
    Plugins cannot replace exchanges with columnar versions when AQE is enabled without this patch.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Tests included.
    
    Closes #29310 from andygrove/backport-SPARK-32332.
    
    Authored-by: Andy Grove <an...@nvidia.com>
    Signed-off-by: Thomas Graves <tg...@apache.org>
---
 .../execution/adaptive/AdaptiveSparkPlanExec.scala |  30 ++++--
 .../adaptive/CustomShuffleReaderExec.scala         |  37 ++++---
 .../adaptive/OptimizeLocalShuffleReader.scala      |   5 +-
 .../execution/adaptive/OptimizeSkewedJoin.scala    |  17 +--
 .../sql/execution/adaptive/QueryStageExec.scala    |  24 +++--
 .../sql/execution/adaptive/simpleCosting.scala     |   6 +-
 .../execution/exchange/BroadcastExchangeExec.scala |  42 +++++++-
 .../execution/exchange/ShuffleExchangeExec.scala   |  55 +++++++++-
 .../execution/streaming/IncrementalExecution.scala |   4 +-
 .../spark/sql/SparkSessionExtensionSuite.scala     | 120 +++++++++++++++++----
 10 files changed, 272 insertions(+), 68 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 5714c33..8b59b12 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -100,7 +100,12 @@ case class AdaptiveSparkPlanExec(
     // The following two rules need to make use of 'CustomShuffleReaderExec.partitionSpecs'
     // added by `CoalesceShufflePartitions`. So they must be executed after it.
     OptimizeSkewedJoin(conf),
-    OptimizeLocalShuffleReader(conf),
+    OptimizeLocalShuffleReader(conf)
+  )
+
+  // A list of physical optimizer rules to be applied right after a new stage is created (the input
+  // plan then has an exchange as its root node) and, at the very end, to the final plan.
+  @transient private val postStageCreationRules = Seq(
     ApplyColumnarRulesAndInsertTransitions(conf, context.session.sessionState.columnarRules),
     CollapseCodegenStages(conf)
   )
@@ -227,7 +232,8 @@ case class AdaptiveSparkPlanExec(
       }
 
       // Run the final plan when there's no more unfinished stages.
-      currentPhysicalPlan = applyPhysicalRules(result.newPlan, queryStageOptimizerRules)
+      currentPhysicalPlan = applyPhysicalRules(
+        result.newPlan, queryStageOptimizerRules ++ postStageCreationRules)
       isFinalPlan = true
       executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
       currentPhysicalPlan
@@ -375,10 +381,22 @@ case class AdaptiveSparkPlanExec(
   private def newQueryStage(e: Exchange): QueryStageExec = {
     val optimizedPlan = applyPhysicalRules(e.child, queryStageOptimizerRules)
     val queryStage = e match {
-      case s: ShuffleExchangeExec =>
-        ShuffleQueryStageExec(currentStageId, s.copy(child = optimizedPlan))
-      case b: BroadcastExchangeExec =>
-        BroadcastQueryStageExec(currentStageId, b.copy(child = optimizedPlan))
+      case s: ShuffleExchangeLike =>
+        val newShuffle = applyPhysicalRules(
+          s.withNewChildren(Seq(optimizedPlan)), postStageCreationRules)
+        if (!newShuffle.isInstanceOf[ShuffleExchangeLike]) {
+          throw new IllegalStateException(
+            "Custom columnar rules cannot transform shuffle node to something else.")
+        }
+        ShuffleQueryStageExec(currentStageId, newShuffle)
+      case b: BroadcastExchangeLike =>
+        val newBroadcast = applyPhysicalRules(
+          b.withNewChildren(Seq(optimizedPlan)), postStageCreationRules)
+        if (!newBroadcast.isInstanceOf[BroadcastExchangeLike]) {
+          throw new IllegalStateException(
+            "Custom columnar rules cannot transform broadcast node to something else.")
+        }
+        BroadcastQueryStageExec(currentStageId, newBroadcast)
     }
     currentStageId += 1
     setLogicalLinkForNewQueryStage(queryStage, e)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala
index ba3f725..8fd5720 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala
@@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
+import org.apache.spark.sql.vectorized.ColumnarBatch
 
 
 /**
@@ -38,6 +39,8 @@ case class CustomShuffleReaderExec private(
     partitionSpecs: Seq[ShufflePartitionSpec],
     description: String) extends UnaryExecNode {
 
+  override def supportsColumnar: Boolean = child.supportsColumnar
+
   override def output: Seq[Attribute] = child.output
   override lazy val outputPartitioning: Partitioning = {
     // If it is a local shuffle reader with one mapper per task, then the output partitioning is
@@ -47,9 +50,9 @@ case class CustomShuffleReaderExec private(
         partitionSpecs.map(_.asInstanceOf[PartialMapperPartitionSpec].mapIndex).toSet.size ==
           partitionSpecs.length) {
       child match {
-        case ShuffleQueryStageExec(_, s: ShuffleExchangeExec) =>
+        case ShuffleQueryStageExec(_, s: ShuffleExchangeLike) =>
           s.child.outputPartitioning
-        case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeExec)) =>
+        case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeLike)) =>
           s.child.outputPartitioning match {
             case e: Expression => r.updateAttr(e).asInstanceOf[Partitioning]
             case other => other
@@ -64,18 +67,24 @@ case class CustomShuffleReaderExec private(
 
   override def stringArgs: Iterator[Any] = Iterator(description)
 
-  private var cachedShuffleRDD: RDD[InternalRow] = null
+  private def shuffleStage = child match {
+    case stage: ShuffleQueryStageExec => Some(stage)
+    case _ => None
+  }
 
-  override protected def doExecute(): RDD[InternalRow] = {
-    if (cachedShuffleRDD == null) {
-      cachedShuffleRDD = child match {
-        case stage: ShuffleQueryStageExec =>
-          new ShuffledRowRDD(
-            stage.shuffle.shuffleDependency, stage.shuffle.readMetrics, partitionSpecs.toArray)
-        case _ =>
-          throw new IllegalStateException("operating on canonicalization plan")
-      }
+  private lazy val shuffleRDD: RDD[_] = {
+    shuffleStage.map { stage =>
+      stage.shuffle.getShuffleRDD(partitionSpecs.toArray)
+    }.getOrElse {
+      throw new IllegalStateException("operating on canonicalized plan")
     }
-    cachedShuffleRDD
+  }
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    shuffleRDD.asInstanceOf[RDD[InternalRow]]
+  }
+
+  override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
+    shuffleRDD.asInstanceOf[RDD[ColumnarBatch]]
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
index fb6b40c..6684376 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
@@ -78,10 +78,9 @@ case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {
   private def getPartitionSpecs(
       shuffleStage: ShuffleQueryStageExec,
       advisoryParallelism: Option[Int]): Seq[ShufflePartitionSpec] = {
-    val shuffleDep = shuffleStage.shuffle.shuffleDependency
-    val numReducers = shuffleDep.partitioner.numPartitions
+    val numMappers = shuffleStage.shuffle.numMappers
+    val numReducers = shuffleStage.shuffle.numPartitions
     val expectedParallelism = advisoryParallelism.getOrElse(numReducers)
-    val numMappers = shuffleDep.rdd.getNumPartitions
     val splitPoints = if (numMappers == 0) {
       Seq.empty
     } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
index 91ae0b9..b3b3eb2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
 
 import org.apache.commons.io.FileUtils
 
-import org.apache.spark.{MapOutputTrackerMaster, SparkEnv}
+import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
@@ -197,7 +197,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
         val leftParts = if (isLeftSkew && !isLeftCoalesced) {
           val reducerId = leftPartSpec.startReducerIndex
           val skewSpecs = createSkewPartitionSpecs(
-            left.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, leftTargetSize)
+            left.mapStats.shuffleId, reducerId, leftTargetSize)
           if (skewSpecs.isDefined) {
             logDebug(s"Left side partition $partitionIndex is skewed, split it into " +
               s"${skewSpecs.get.length} parts.")
@@ -212,7 +212,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
         val rightParts = if (isRightSkew && !isRightCoalesced) {
           val reducerId = rightPartSpec.startReducerIndex
           val skewSpecs = createSkewPartitionSpecs(
-            right.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, rightTargetSize)
+            right.mapStats.shuffleId, reducerId, rightTargetSize)
           if (skewSpecs.isDefined) {
             logDebug(s"Right side partition $partitionIndex is skewed, split it into " +
               s"${skewSpecs.get.length} parts.")
@@ -287,15 +287,17 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
 private object ShuffleStage {
   def unapply(plan: SparkPlan): Option[ShuffleStageInfo] = plan match {
     case s: ShuffleQueryStageExec if s.mapStats.isDefined =>
-      val sizes = s.mapStats.get.bytesByPartitionId
+      val mapStats = s.mapStats.get
+      val sizes = mapStats.bytesByPartitionId
       val partitions = sizes.zipWithIndex.map {
         case (size, i) => CoalescedPartitionSpec(i, i + 1) -> size
       }
-      Some(ShuffleStageInfo(s, partitions))
+      Some(ShuffleStageInfo(s, mapStats, partitions))
 
     case CustomShuffleReaderExec(s: ShuffleQueryStageExec, partitionSpecs, _)
       if s.mapStats.isDefined && partitionSpecs.nonEmpty =>
-      val sizes = s.mapStats.get.bytesByPartitionId
+      val mapStats = s.mapStats.get
+      val sizes = mapStats.bytesByPartitionId
       val partitions = partitionSpecs.map {
         case spec @ CoalescedPartitionSpec(start, end) =>
           var sum = 0L
@@ -308,7 +310,7 @@ private object ShuffleStage {
         case other => throw new IllegalArgumentException(
           s"Expect CoalescedPartitionSpec but got $other")
       }
-      Some(ShuffleStageInfo(s, partitions))
+      Some(ShuffleStageInfo(s, mapStats, partitions))
 
     case _ => None
   }
@@ -316,6 +318,7 @@ private object ShuffleStage {
 
 private case class ShuffleStageInfo(
     shuffleStage: ShuffleQueryStageExec,
+    mapStats: MapOutputStatistics,
     partitionsWithSizes: Seq[(CoalescedPartitionSpec, Long)])
 
 private class SkewDesc {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
index 9a9a8b1..74fe1ea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.exchange._
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.ThreadUtils
 
 /**
@@ -81,6 +82,11 @@ abstract class QueryStageExec extends LeafExecNode {
   def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec
 
   /**
+   * Returns the runtime statistics after stage materialization.
+   */
+  def getRuntimeStatistics: Statistics
+
+  /**
    * Compute the statistics of the query stage if executed, otherwise None.
    */
   def computeStats(): Option[Statistics] = resultOption.map { _ =>
@@ -107,6 +113,8 @@ abstract class QueryStageExec extends LeafExecNode {
 
   protected override def doPrepare(): Unit = plan.prepare()
   protected override def doExecute(): RDD[InternalRow] = plan.execute()
+  override def supportsColumnar: Boolean = plan.supportsColumnar
+  protected override def doExecuteColumnar(): RDD[ColumnarBatch] = plan.executeColumnar()
   override def doExecuteBroadcast[T](): Broadcast[T] = plan.executeBroadcast()
   override def doCanonicalize(): SparkPlan = plan.canonicalized
 
@@ -135,15 +143,15 @@ abstract class QueryStageExec extends LeafExecNode {
 }
 
 /**
- * A shuffle query stage whose child is a [[ShuffleExchangeExec]] or [[ReusedExchangeExec]].
+ * A shuffle query stage whose child is a [[ShuffleExchangeLike]] or [[ReusedExchangeExec]].
  */
 case class ShuffleQueryStageExec(
     override val id: Int,
     override val plan: SparkPlan) extends QueryStageExec {
 
   @transient val shuffle = plan match {
-    case s: ShuffleExchangeExec => s
-    case ReusedExchangeExec(_, s: ShuffleExchangeExec) => s
+    case s: ShuffleExchangeLike => s
+    case ReusedExchangeExec(_, s: ShuffleExchangeLike) => s
     case _ =>
       throw new IllegalStateException("wrong plan for shuffle stage:\n " + plan.treeString)
   }
@@ -176,18 +184,20 @@ case class ShuffleQueryStageExec(
     val stats = resultOption.get.asInstanceOf[MapOutputStatistics]
     Option(stats)
   }
+
+  override def getRuntimeStatistics: Statistics = shuffle.runtimeStatistics
 }
 
 /**
- * A broadcast query stage whose child is a [[BroadcastExchangeExec]] or [[ReusedExchangeExec]].
+ * A broadcast query stage whose child is a [[BroadcastExchangeLike]] or [[ReusedExchangeExec]].
  */
 case class BroadcastQueryStageExec(
     override val id: Int,
     override val plan: SparkPlan) extends QueryStageExec {
 
   @transient val broadcast = plan match {
-    case b: BroadcastExchangeExec => b
-    case ReusedExchangeExec(_, b: BroadcastExchangeExec) => b
+    case b: BroadcastExchangeLike => b
+    case ReusedExchangeExec(_, b: BroadcastExchangeLike) => b
     case _ =>
       throw new IllegalStateException("wrong plan for broadcast stage:\n " + plan.treeString)
   }
@@ -224,6 +234,8 @@ case class BroadcastQueryStageExec(
       broadcast.relationFuture.cancel(true)
     }
   }
+
+  override def getRuntimeStatistics: Statistics = broadcast.runtimeStatistics
 }
 
 object BroadcastQueryStageExec {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala
index 67cd720..cdc57db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.adaptive
 
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
 
 /**
  * A simple implementation of [[Cost]], which takes a number of [[Long]] as the cost value.
@@ -35,13 +35,13 @@ case class SimpleCost(value: Long) extends Cost {
 
 /**
  * A simple implementation of [[CostEvaluator]], which counts the number of
- * [[ShuffleExchangeExec]] nodes in the plan.
+ * [[ShuffleExchangeLike]] nodes in the plan.
  */
 object SimpleCostEvaluator extends CostEvaluator {
 
   override def evaluateCost(plan: SparkPlan): Cost = {
     val cost = plan.collect {
-      case s: ShuffleExchangeExec => s
+      case s: ShuffleExchangeLike => s
     }.size
     SimpleCost(cost)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index d35bbe9..bcdaf61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -29,6 +29,7 @@ import org.apache.spark.launcher.SparkLauncher
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning}
 import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.joins.HashedRelation
@@ -38,15 +39,42 @@ import org.apache.spark.unsafe.map.BytesToBytesMap
 import org.apache.spark.util.{SparkFatalException, ThreadUtils}
 
 /**
+ * Common trait for all broadcast exchange implementations to facilitate pattern matching.
+ */
+trait BroadcastExchangeLike extends Exchange {
+
+  /**
+   * The broadcast job group ID. Note: this default returns a new random UUID on every call.
+   */
+  def runId: UUID = UUID.randomUUID
+
+  /**
+   * The asynchronous job that prepares the broadcast relation.
+   */
+  def relationFuture: Future[broadcast.Broadcast[Any]]
+
+  /**
+   * For registering callbacks on `relationFuture`.
+   * Note that calling this method may not start the execution of the broadcast job.
+   */
+  def completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]]
+
+  /**
+   * Returns the runtime statistics after broadcast materialization.
+   */
+  def runtimeStatistics: Statistics
+}
+
+/**
  * A [[BroadcastExchangeExec]] collects, transforms and finally broadcasts the result of
  * a transformed SparkPlan.
  */
 case class BroadcastExchangeExec(
     mode: BroadcastMode,
-    child: SparkPlan) extends Exchange {
+    child: SparkPlan) extends BroadcastExchangeLike {
   import BroadcastExchangeExec._
 
-  private[sql] val runId: UUID = UUID.randomUUID
+  override val runId: UUID = UUID.randomUUID
 
   override lazy val metrics = Map(
     "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
@@ -60,6 +88,11 @@ case class BroadcastExchangeExec(
     BroadcastExchangeExec(mode.canonicalized, child.canonicalized)
   }
 
+  override def runtimeStatistics: Statistics = {
+    val dataSize = metrics("dataSize").value
+    Statistics(dataSize)
+  }
+
   @transient
   private lazy val promise = Promise[broadcast.Broadcast[Any]]()
 
@@ -68,13 +101,14 @@ case class BroadcastExchangeExec(
    * Note that calling this field will not start the execution of broadcast job.
    */
   @transient
-  lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] = promise.future
+  override lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] =
+    promise.future
 
   @transient
   private val timeout: Long = SQLConf.get.broadcastTimeout
 
   @transient
-  private[sql] lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
+  override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
     SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
       sqlContext.sparkSession, BroadcastExchangeExec.executionContext) {
           try {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index b06742e..b7da78c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Divide, Literal, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
@@ -41,12 +42,48 @@ import org.apache.spark.util.MutablePair
 import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}
 
 /**
+ * Common trait for all shuffle exchange implementations to facilitate pattern matching.
+ */
+trait ShuffleExchangeLike extends Exchange {
+
+  /**
+   * Returns the number of mappers of this shuffle.
+   */
+  def numMappers: Int
+
+  /**
+   * Returns the number of shuffle (reducer) partitions.
+   */
+  def numPartitions: Int
+
+  /**
+   * Returns whether the shuffle partition number can be changed.
+   */
+  def canChangeNumPartitions: Boolean
+
+  /**
+   * The asynchronous job that materializes the shuffle.
+   */
+  def mapOutputStatisticsFuture: Future[MapOutputStatistics]
+
+  /**
+   * Returns the shuffle RDD (of rows or columnar batches) for the specified partition specs.
+   */
+  def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_]
+
+  /**
+   * Returns the runtime statistics after shuffle materialization.
+   */
+  def runtimeStatistics: Statistics
+}
+
+/**
  * Performs a shuffle that will result in the desired partitioning.
  */
 case class ShuffleExchangeExec(
     override val outputPartitioning: Partitioning,
     child: SparkPlan,
-    canChangeNumPartitions: Boolean = true) extends Exchange {
+    canChangeNumPartitions: Boolean = true) extends ShuffleExchangeLike {
 
   private lazy val writeMetrics =
     SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
@@ -64,7 +101,7 @@ case class ShuffleExchangeExec(
   @transient lazy val inputRDD: RDD[InternalRow] = child.execute()
 
   // 'mapOutputStatisticsFuture' is only needed when enable AQE.
-  @transient lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = {
+  @transient override lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = {
     if (inputRDD.getNumPartitions == 0) {
       Future.successful(null)
     } else {
@@ -72,6 +109,20 @@ case class ShuffleExchangeExec(
     }
   }
 
+  override def numMappers: Int = shuffleDependency.rdd.getNumPartitions
+
+  override def numPartitions: Int = shuffleDependency.partitioner.numPartitions
+
+  override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[InternalRow] = {
+    new ShuffledRowRDD(shuffleDependency, readMetrics, partitionSpecs)
+  }
+
+  override def runtimeStatistics: Statistics = {
+    val dataSize = metrics("dataSize").value
+    val rowCount = metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN).value
+    Statistics(dataSize, Some(rowCount))
+  }
+
   /**
    * A [[ShuffleDependency]] that will partition rows of its child based on
    * the partitioning scheme defined in `newPartitioning`. Those partitions of
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 7773ac7..bfa60cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, HashPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{LeafExecNode, LocalLimitExec, QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode}
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.util.Utils
@@ -118,7 +118,7 @@ class IncrementalExecution(
           case s: StatefulOperator =>
             statefulOpFound = true
 
-          case e: ShuffleExchangeExec =>
+          case e: ShuffleExchangeLike =>
             // Don't search recursively any further as any child stateful operator as we
             // are only looking for stateful subplans that this plan has narrow dependencies on.
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 44e784d..e5e8bc6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -16,19 +16,24 @@
  */
 package org.apache.spark.sql
 
-import java.util.Locale
+import java.util.{Locale, UUID}
 
-import org.apache.spark.{SparkFunSuite, TaskContext}
+import scala.concurrent.Future
+
+import org.apache.spark.{MapOutputStatistics, SparkFunSuite, TaskContext}
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Statistics, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike}
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.COLUMN_BATCH_SIZE
@@ -169,33 +174,61 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
     }
   }
 
-  test("inject columnar") {
+  test("inject columnar AQE on") {
+    testInjectColumnar(true)
+  }
+
+  test("inject columnar AQE off") {
+    testInjectColumnar(false)
+  }
+
+  private def testInjectColumnar(enableAQE: Boolean): Unit = {
+    def collectPlanSteps(plan: SparkPlan): Seq[Int] = plan match {
+      case a: AdaptiveSparkPlanExec =>
+        assert(a.toString.startsWith("AdaptiveSparkPlan isFinalPlan=true"))
+        collectPlanSteps(a.executedPlan)
+      case _ => plan.collect {
+        case _: ReplacedRowToColumnarExec => 1
+        case _: ColumnarProjectExec => 10
+        case _: ColumnarToRowExec => 100
+        case s: QueryStageExec => collectPlanSteps(s.plan).sum
+        case _: MyShuffleExchangeExec => 1000
+        case _: MyBroadcastExchangeExec => 10000
+      }
+    }
+
     val extensions = create { extensions =>
       extensions.injectColumnar(session =>
         MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule()))
     }
     withSession(extensions) { session =>
-      // The ApplyColumnarRulesAndInsertTransitions rule is not applied when enable AQE
-      session.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, false)
+      session.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, enableAQE)
       assert(session.sessionState.columnarRules.contains(
         MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
       import session.sqlContext.implicits._
-      // repartitioning avoids having the add operation pushed up into the LocalTableScan
-      val data = Seq((100L), (200L), (300L)).toDF("vals").repartition(1)
-      val df = data.selectExpr("vals + 1")
-      // Verify that both pre and post processing of the plan worked.
-      val found = df.queryExecution.executedPlan.collect {
-        case rep: ReplacedRowToColumnarExec => 1
-        case proj: ColumnarProjectExec => 10
-        case c2r: ColumnarToRowExec => 100
-      }.sum
-      assert(found == 111)
+      // perform a join to inject a broadcast exchange
+      val left = Seq((1, 50L), (2, 100L), (3, 150L)).toDF("l1", "l2")
+      val right = Seq((1, 50L), (2, 100L), (3, 150L)).toDF("r1", "r2")
+      val data = left.join(right, $"l1" === $"r1")
+        // repartitioning avoids having the add operation pushed up into the LocalTableScan
+        .repartition(1)
+      val df = data.selectExpr("l2 + r2")
+      // execute the plan so that the final adaptive plan is available when AQE is on
+      df.collect()
+      val found = collectPlanSteps(df.queryExecution.executedPlan).sum
+      // 1 MyBroadcastExchangeExec
+      // 1 MyShuffleExchangeExec
+      // 1 ColumnarToRowExec
+      // 2 ColumnarProjectExec
+      // 1 ReplacedRowToColumnarExec
+      // so 11121 is expected.
+      assert(found == 11121)
 
       // Verify that we get back the expected, wrong, result
       val result = df.collect()
-      assert(result(0).getLong(0) == 102L) // Check that broken columnar Add was used.
-      assert(result(1).getLong(0) == 202L)
-      assert(result(2).getLong(0) == 302L)
+      assert(result(0).getLong(0) == 101L) // Check that broken columnar Add was used.
+      assert(result(1).getLong(0) == 201L)
+      assert(result(2).getLong(0) == 301L)
     }
   }
 
@@ -695,6 +728,16 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] {
   def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan =
     try {
       plan match {
+        case e: ShuffleExchangeExec =>
+          // note that this is not actually columnar but demonstrates that exchanges can
+          // be replaced.
+          val replaced = e.withNewChildren(e.children.map(replaceWithColumnarPlan))
+          MyShuffleExchangeExec(replaced.asInstanceOf[ShuffleExchangeExec])
+        case e: BroadcastExchangeExec =>
+          // note that this is not actually columnar but demonstrates that exchanges can
+          // be replaced.
+          val replaced = e.withNewChildren(e.children.map(replaceWithColumnarPlan))
+          MyBroadcastExchangeExec(replaced.asInstanceOf[BroadcastExchangeExec])
         case plan: ProjectExec =>
           new ColumnarProjectExec(plan.projectList.map((exp) =>
             replaceWithColumnarExpression(exp).asInstanceOf[NamedExpression]),
@@ -713,6 +756,41 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] {
   override def apply(plan: SparkPlan): SparkPlan = replaceWithColumnarPlan(plan)
 }
 
+/**
+ * Custom Exchange used in tests to demonstrate that shuffles can be replaced regardless of
+ * whether AQE is enabled. Every member forwards to the wrapped `delegate` exchange.
+ */
+case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleExchangeLike {
+  override def numMappers: Int = delegate.numMappers
+  override def numPartitions: Int = delegate.numPartitions
+  override def canChangeNumPartitions: Boolean = delegate.canChangeNumPartitions
+  override def mapOutputStatisticsFuture: Future[MapOutputStatistics] =
+    delegate.mapOutputStatisticsFuture
+  override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] =
+    delegate.getShuffleRDD(partitionSpecs)
+  override def runtimeStatistics: Statistics = delegate.runtimeStatistics
+  override def child: SparkPlan = delegate.child
+  override protected def doExecute(): RDD[InternalRow] = delegate.execute()
+  override def outputPartitioning: Partitioning = delegate.outputPartitioning
+}
+
+/**
+ * Custom Exchange used in tests to demonstrate that broadcasts can be replaced regardless of
+ * whether AQE is enabled. Every member forwards to the wrapped `delegate` exchange.
+ */
+case class MyBroadcastExchangeExec(delegate: BroadcastExchangeExec) extends BroadcastExchangeLike {
+  override def runId: UUID = delegate.runId
+  override def relationFuture: java.util.concurrent.Future[Broadcast[Any]] =
+    delegate.relationFuture
+  override def completionFuture: Future[Broadcast[Any]] = delegate.completionFuture
+  override def runtimeStatistics: Statistics = delegate.runtimeStatistics
+  override def child: SparkPlan = delegate.child
+  override protected def doPrepare(): Unit = delegate.prepare()
+  override protected def doExecute(): RDD[InternalRow] = delegate.execute()
+  override def doExecuteBroadcast[T](): Broadcast[T] = delegate.executeBroadcast()
+  override def outputPartitioning: Partitioning = delegate.outputPartitioning
+}
+
 class ReplacedRowToColumnarExec(override val child: SparkPlan)
   extends RowToColumnarExec(child) {
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org