Posted to commits@spark.apache.org by tg...@apache.org on 2020/07/29 19:22:54 UTC

[spark] branch master updated: [SPARK-32332][SQL] Support columnar exchanges

This is an automated email from the ASF dual-hosted git repository.

tgraves pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a025a89  [SPARK-32332][SQL] Support columnar exchanges
a025a89 is described below

commit a025a89f4ef3a05d7e70c02f03a9826bb97eceac
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Wed Jul 29 14:21:47 2020 -0500

    [SPARK-32332][SQL] Support columnar exchanges
    
    ### What changes were proposed in this pull request?
    
    This PR adds abstract classes for shuffle and broadcast exchanges, so that users can provide their own columnar implementations.
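
    As an illustrative sketch (the class name is hypothetical; it mirrors the `MyShuffleExchangeExec` test helper added to `SparkSessionExtensionSuite` below), a custom shuffle only needs to implement the new `ShuffleExchangeLike` contract, typically delegating to a built-in exchange for anything it does not override:

    ```scala
    import scala.concurrent.Future

    import org.apache.spark.MapOutputStatistics
    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.plans.logical.Statistics
    import org.apache.spark.sql.catalyst.plans.physical.Partitioning
    import org.apache.spark.sql.execution.{ShufflePartitionSpec, SparkPlan}
    import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}

    // Hypothetical custom exchange: implements ShuffleExchangeLike by delegating to a
    // built-in ShuffleExchangeExec. A real columnar implementation would additionally
    // override supportsColumnar/doExecuteColumnar and return a columnar RDD from
    // getShuffleRDD.
    case class DelegatingShuffleExchangeExec(delegate: ShuffleExchangeExec)
      extends ShuffleExchangeLike {
      override def numMappers: Int = delegate.numMappers
      override def numPartitions: Int = delegate.numPartitions
      override def canChangeNumPartitions: Boolean = delegate.canChangeNumPartitions
      override def mapOutputStatisticsFuture: Future[MapOutputStatistics] =
        delegate.mapOutputStatisticsFuture
      override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] =
        delegate.getShuffleRDD(partitionSpecs)
      override def runtimeStatistics: Statistics = delegate.runtimeStatistics
      override def child: SparkPlan = delegate.child
      override def outputPartitioning: Partitioning = delegate.outputPartitioning
      override protected def doExecute(): RDD[InternalRow] = delegate.execute()
    }
    ```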
    
    This PR updates several places to use the abstract exchange classes, and also updates `AdaptiveSparkPlanExec` so that the columnar rules can see exchange nodes.
    
    This is an alternative to https://github.com/apache/spark/pull/29134 .
    Closes https://github.com/apache/spark/pull/29134
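
    A sketch of how such a node could be wired in through the existing `injectColumnar` extension point (hypothetical names; the same pattern is exercised by `PreRuleReplaceAddWithBrokenVersion` in the updated test suite). With AQE on, `AdaptiveSparkPlanExec` now applies these columnar rules right after each query stage is created, so the rule sees the exchange at the stage root, and the result must still be a `ShuffleExchangeLike`/`BroadcastExchangeLike`:

    ```scala
    import org.apache.spark.sql.SparkSessionExtensions
    import org.apache.spark.sql.catalyst.rules.Rule
    import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
    import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

    // Hypothetical rule: replace every built-in shuffle with the delegating
    // ShuffleExchangeLike node sketched above.
    case class ReplaceShuffles() extends Rule[SparkPlan] {
      override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
        case s: ShuffleExchangeExec => DelegatingShuffleExchangeExec(s)
      }
    }

    // Registered through SparkSessionExtensions, like any other columnar rule.
    class MyExtensions extends (SparkSessionExtensions => Unit) {
      override def apply(extensions: SparkSessionExtensions): Unit = {
        extensions.injectColumnar(_ => new ColumnarRule {
          override def preColumnarTransitions: Rule[SparkPlan] = ReplaceShuffles()
        })
      }
    }
    ```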
    
    ### Why are the changes needed?
    
    To allow columnar exchanges.
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    new tests
    
    Closes #29262 from cloud-fan/columnar.
    
    Authored-by: Wenchen Fan <we...@databricks.com>
    Signed-off-by: Thomas Graves <tg...@apache.org>
---
 .../execution/adaptive/AdaptiveSparkPlanExec.scala |  30 ++++--
 .../adaptive/CustomShuffleReaderExec.scala         |  21 ++--
 .../adaptive/OptimizeLocalShuffleReader.scala      |   5 +-
 .../execution/adaptive/OptimizeSkewedJoin.scala    |   4 +-
 .../sql/execution/adaptive/QueryStageExec.scala    |  37 ++++---
 .../sql/execution/adaptive/simpleCosting.scala     |   6 +-
 .../execution/exchange/BroadcastExchangeExec.scala |  46 ++++++--
 .../execution/exchange/ShuffleExchangeExec.scala   |  57 +++++++++-
 .../execution/streaming/IncrementalExecution.scala |   4 +-
 .../spark/sql/SparkSessionExtensionSuite.scala     | 120 +++++++++++++++++----
 10 files changed, 260 insertions(+), 70 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 34db0a3..b160b8a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -100,7 +100,12 @@ case class AdaptiveSparkPlanExec(
     // The following two rules need to make use of 'CustomShuffleReaderExec.partitionSpecs'
     // added by `CoalesceShufflePartitions`. So they must be executed after it.
     OptimizeSkewedJoin(conf),
-    OptimizeLocalShuffleReader(conf),
+    OptimizeLocalShuffleReader(conf)
+  )
+
+  // A list of physical optimizer rules to be applied right after a new stage is created. The input
+  // plan to these rules has an exchange node as its root.
+  @transient private val postStageCreationRules = Seq(
     ApplyColumnarRulesAndInsertTransitions(conf, context.session.sessionState.columnarRules),
     CollapseCodegenStages(conf)
   )
@@ -227,7 +232,8 @@ case class AdaptiveSparkPlanExec(
       }
 
       // Run the final plan when there's no more unfinished stages.
-      currentPhysicalPlan = applyPhysicalRules(result.newPlan, queryStageOptimizerRules)
+      currentPhysicalPlan = applyPhysicalRules(
+        result.newPlan, queryStageOptimizerRules ++ postStageCreationRules)
       isFinalPlan = true
       executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
       currentPhysicalPlan
@@ -376,10 +382,22 @@ case class AdaptiveSparkPlanExec(
   private def newQueryStage(e: Exchange): QueryStageExec = {
     val optimizedPlan = applyPhysicalRules(e.child, queryStageOptimizerRules)
     val queryStage = e match {
-      case s: ShuffleExchangeExec =>
-        ShuffleQueryStageExec(currentStageId, s.copy(child = optimizedPlan))
-      case b: BroadcastExchangeExec =>
-        BroadcastQueryStageExec(currentStageId, b.copy(child = optimizedPlan))
+      case s: ShuffleExchangeLike =>
+        val newShuffle = applyPhysicalRules(
+          s.withNewChildren(Seq(optimizedPlan)), postStageCreationRules)
+        if (!newShuffle.isInstanceOf[ShuffleExchangeLike]) {
+          throw new IllegalStateException(
+            "Custom columnar rules cannot transform shuffle node to something else.")
+        }
+        ShuffleQueryStageExec(currentStageId, newShuffle)
+      case b: BroadcastExchangeLike =>
+        val newBroadcast = applyPhysicalRules(
+          b.withNewChildren(Seq(optimizedPlan)), postStageCreationRules)
+        if (!newBroadcast.isInstanceOf[BroadcastExchangeLike]) {
+          throw new IllegalStateException(
+            "Custom columnar rules cannot transform broadcast node to something else.")
+        }
+        BroadcastQueryStageExec(currentStageId, newBroadcast)
     }
     currentStageId += 1
     setLogicalLinkForNewQueryStage(queryStage, e)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala
index af18ee0..49a4c25 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.execution.adaptive
 
-import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.rdd.RDD
@@ -25,8 +24,9 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeLike}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.vectorized.ColumnarBatch
 
 
 /**
@@ -45,6 +45,8 @@ case class CustomShuffleReaderExec private(
     assert(partitionSpecs.forall(_.isInstanceOf[PartialMapperPartitionSpec]))
   }
 
+  override def supportsColumnar: Boolean = child.supportsColumnar
+
   override def output: Seq[Attribute] = child.output
   override lazy val outputPartitioning: Partitioning = {
     // If it is a local shuffle reader with one mapper per task, then the output partitioning is
@@ -55,9 +57,9 @@ case class CustomShuffleReaderExec private(
         partitionSpecs.map(_.asInstanceOf[PartialMapperPartitionSpec].mapIndex).toSet.size ==
           partitionSpecs.length) {
       child match {
-        case ShuffleQueryStageExec(_, s: ShuffleExchangeExec) =>
+        case ShuffleQueryStageExec(_, s: ShuffleExchangeLike) =>
           s.child.outputPartitioning
-        case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeExec)) =>
+        case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeLike)) =>
           s.child.outputPartitioning match {
             case e: Expression => r.updateAttr(e).asInstanceOf[Partitioning]
             case other => other
@@ -176,18 +178,21 @@ case class CustomShuffleReaderExec private(
     }
   }
 
-  private lazy val cachedShuffleRDD: RDD[InternalRow] = {
+  private lazy val shuffleRDD: RDD[_] = {
     sendDriverMetrics()
 
     shuffleStage.map { stage =>
-      new ShuffledRowRDD(
-        stage.shuffle.shuffleDependency, stage.shuffle.readMetrics, partitionSpecs.toArray)
+      stage.shuffle.getShuffleRDD(partitionSpecs.toArray)
     }.getOrElse {
       throw new IllegalStateException("operating on canonicalized plan")
     }
   }
 
   override protected def doExecute(): RDD[InternalRow] = {
-    cachedShuffleRDD
+    shuffleRDD.asInstanceOf[RDD[InternalRow]]
+  }
+
+  override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
+    shuffleRDD.asInstanceOf[RDD[ColumnarBatch]]
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
index 3620f27..45fb364 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
@@ -78,10 +78,9 @@ case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {
   private def getPartitionSpecs(
       shuffleStage: ShuffleQueryStageExec,
       advisoryParallelism: Option[Int]): Seq[ShufflePartitionSpec] = {
-    val shuffleDep = shuffleStage.shuffle.shuffleDependency
-    val numReducers = shuffleDep.partitioner.numPartitions
+    val numMappers = shuffleStage.shuffle.numMappers
+    val numReducers = shuffleStage.shuffle.numPartitions
     val expectedParallelism = advisoryParallelism.getOrElse(numReducers)
-    val numMappers = shuffleDep.rdd.getNumPartitions
     val splitPoints = if (numMappers == 0) {
       Seq.empty
     } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
index 627f060..a85b188 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
@@ -202,7 +202,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
         val leftParts = if (isLeftSkew && !isLeftCoalesced) {
           val reducerId = leftPartSpec.startReducerIndex
           val skewSpecs = createSkewPartitionSpecs(
-            left.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, leftTargetSize)
+            left.mapStats.shuffleId, reducerId, leftTargetSize)
           if (skewSpecs.isDefined) {
             logDebug(s"Left side partition $partitionIndex " +
               s"(${FileUtils.byteCountToDisplaySize(leftActualSize)}) is skewed, " +
@@ -218,7 +218,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
         val rightParts = if (isRightSkew && !isRightCoalesced) {
           val reducerId = rightPartSpec.startReducerIndex
           val skewSpecs = createSkewPartitionSpecs(
-            right.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, rightTargetSize)
+            right.mapStats.shuffleId, reducerId, rightTargetSize)
           if (skewSpecs.isDefined) {
             logDebug(s"Right side partition $partitionIndex " +
               s"(${FileUtils.byteCountToDisplaySize(rightActualSize)}) is skewed, " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
index 4e83b43..0927ef5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.exchange._
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.ThreadUtils
 
 /**
@@ -82,16 +83,18 @@ abstract class QueryStageExec extends LeafExecNode {
   def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec
 
   /**
+   * Returns the runtime statistics after stage materialization.
+   */
+  def getRuntimeStatistics: Statistics
+
+  /**
    * Compute the statistics of the query stage if executed, otherwise None.
    */
   def computeStats(): Option[Statistics] = resultOption.get().map { _ =>
-    // Metrics `dataSize` are available in both `ShuffleExchangeExec` and `BroadcastExchangeExec`.
-    val exchange = plan match {
-      case r: ReusedExchangeExec => r.child
-      case e: Exchange => e
-      case _ => throw new IllegalStateException("wrong plan for query stage:\n " + plan.treeString)
-    }
-    Statistics(sizeInBytes = exchange.metrics("dataSize").value)
+    val runtimeStats = getRuntimeStatistics
+    val dataSize = runtimeStats.sizeInBytes.max(0)
+    val numOutputRows = runtimeStats.rowCount.map(_.max(0))
+    Statistics(dataSize, numOutputRows)
   }
 
   @transient
@@ -110,6 +113,8 @@ abstract class QueryStageExec extends LeafExecNode {
 
   protected override def doPrepare(): Unit = plan.prepare()
   protected override def doExecute(): RDD[InternalRow] = plan.execute()
+  override def supportsColumnar: Boolean = plan.supportsColumnar
+  protected override def doExecuteColumnar(): RDD[ColumnarBatch] = plan.executeColumnar()
   override def doExecuteBroadcast[T](): Broadcast[T] = plan.executeBroadcast()
   override def doCanonicalize(): SparkPlan = plan.canonicalized
 
@@ -138,15 +143,15 @@ abstract class QueryStageExec extends LeafExecNode {
 }
 
 /**
- * A shuffle query stage whose child is a [[ShuffleExchangeExec]] or [[ReusedExchangeExec]].
+ * A shuffle query stage whose child is a [[ShuffleExchangeLike]] or [[ReusedExchangeExec]].
  */
 case class ShuffleQueryStageExec(
     override val id: Int,
     override val plan: SparkPlan) extends QueryStageExec {
 
   @transient val shuffle = plan match {
-    case s: ShuffleExchangeExec => s
-    case ReusedExchangeExec(_, s: ShuffleExchangeExec) => s
+    case s: ShuffleExchangeLike => s
+    case ReusedExchangeExec(_, s: ShuffleExchangeLike) => s
     case _ =>
       throw new IllegalStateException("wrong plan for shuffle stage:\n " + plan.treeString)
   }
@@ -177,22 +182,24 @@ case class ShuffleQueryStageExec(
    * this method returns None, as there are no map statistics.
    */
   def mapStats: Option[MapOutputStatistics] = {
-    assert(resultOption.get().isDefined, "ShuffleQueryStageExec should already be ready")
+    assert(resultOption.get().isDefined, s"${getClass.getSimpleName} should already be ready")
     val stats = resultOption.get().get.asInstanceOf[MapOutputStatistics]
     Option(stats)
   }
+
+  override def getRuntimeStatistics: Statistics = shuffle.runtimeStatistics
 }
 
 /**
- * A broadcast query stage whose child is a [[BroadcastExchangeExec]] or [[ReusedExchangeExec]].
+ * A broadcast query stage whose child is a [[BroadcastExchangeLike]] or [[ReusedExchangeExec]].
  */
 case class BroadcastQueryStageExec(
     override val id: Int,
     override val plan: SparkPlan) extends QueryStageExec {
 
   @transient val broadcast = plan match {
-    case b: BroadcastExchangeExec => b
-    case ReusedExchangeExec(_, b: BroadcastExchangeExec) => b
+    case b: BroadcastExchangeLike => b
+    case ReusedExchangeExec(_, b: BroadcastExchangeLike) => b
     case _ =>
       throw new IllegalStateException("wrong plan for broadcast stage:\n " + plan.treeString)
   }
@@ -231,6 +238,8 @@ case class BroadcastQueryStageExec(
       broadcast.relationFuture.cancel(true)
     }
   }
+
+  override def getRuntimeStatistics: Statistics = broadcast.runtimeStatistics
 }
 
 object BroadcastQueryStageExec {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala
index 67cd720..cdc57db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.adaptive
 
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
 
 /**
  * A simple implementation of [[Cost]], which takes a number of [[Long]] as the cost value.
@@ -35,13 +35,13 @@ case class SimpleCost(value: Long) extends Cost {
 
 /**
  * A simple implementation of [[CostEvaluator]], which counts the number of
- * [[ShuffleExchangeExec]] nodes in the plan.
+ * [[ShuffleExchangeLike]] nodes in the plan.
  */
 object SimpleCostEvaluator extends CostEvaluator {
 
   override def evaluateCost(plan: SparkPlan): Cost = {
     val cost = plan.collect {
-      case s: ShuffleExchangeExec => s
+      case s: ShuffleExchangeLike => s
     }.size
     SimpleCost(cost)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index d35bbe9..6d8d370 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -29,6 +29,7 @@ import org.apache.spark.launcher.SparkLauncher
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning}
 import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.joins.HashedRelation
@@ -38,15 +39,42 @@ import org.apache.spark.unsafe.map.BytesToBytesMap
 import org.apache.spark.util.{SparkFatalException, ThreadUtils}
 
 /**
+ * Common trait for all broadcast exchange implementations to facilitate pattern matching.
+ */
+trait BroadcastExchangeLike extends Exchange {
+
+  /**
+   * The broadcast job group ID
+   */
+  def runId: UUID = UUID.randomUUID
+
+  /**
+   * The asynchronous job that prepares the broadcast relation.
+   */
+  def relationFuture: Future[broadcast.Broadcast[Any]]
+
+  /**
+   * For registering callbacks on `relationFuture`.
+   * Note that calling this method may not start the execution of the broadcast job.
+   */
+  def completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]]
+
+  /**
+   * Returns the runtime statistics after broadcast materialization.
+   */
+  def runtimeStatistics: Statistics
+}
+
+/**
  * A [[BroadcastExchangeExec]] collects, transforms and finally broadcasts the result of
  * a transformed SparkPlan.
  */
 case class BroadcastExchangeExec(
     mode: BroadcastMode,
-    child: SparkPlan) extends Exchange {
+    child: SparkPlan) extends BroadcastExchangeLike {
   import BroadcastExchangeExec._
 
-  private[sql] val runId: UUID = UUID.randomUUID
+  override val runId: UUID = UUID.randomUUID
 
   override lazy val metrics = Map(
     "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
@@ -60,21 +88,23 @@ case class BroadcastExchangeExec(
     BroadcastExchangeExec(mode.canonicalized, child.canonicalized)
   }
 
+  override def runtimeStatistics: Statistics = {
+    val dataSize = metrics("dataSize").value
+    Statistics(dataSize)
+  }
+
   @transient
   private lazy val promise = Promise[broadcast.Broadcast[Any]]()
 
-  /**
-   * For registering callbacks on `relationFuture`.
-   * Note that calling this field will not start the execution of broadcast job.
-   */
   @transient
-  lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] = promise.future
+  override lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] =
+    promise.future
 
   @transient
   private val timeout: Long = SQLConf.get.broadcastTimeout
 
   @transient
-  private[sql] lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
+  override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
     SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
       sqlContext.sparkSession, BroadcastExchangeExec.executionContext) {
           try {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index b06742e..30c9f0a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -30,8 +30,9 @@ import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProces
 import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Divide, Literal, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
@@ -41,12 +42,48 @@ import org.apache.spark.util.MutablePair
 import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}
 
 /**
+ * Common trait for all shuffle exchange implementations to facilitate pattern matching.
+ */
+trait ShuffleExchangeLike extends Exchange {
+
+  /**
+   * Returns the number of mappers of this shuffle.
+   */
+  def numMappers: Int
+
+  /**
+   * Returns the shuffle partition number.
+   */
+  def numPartitions: Int
+
+  /**
+   * Returns whether the shuffle partition number can be changed.
+   */
+  def canChangeNumPartitions: Boolean
+
+  /**
+   * The asynchronous job that materializes the shuffle.
+   */
+  def mapOutputStatisticsFuture: Future[MapOutputStatistics]
+
+  /**
+   * Returns the shuffle RDD with specified partition specs.
+   */
+  def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_]
+
+  /**
+   * Returns the runtime statistics after shuffle materialization.
+   */
+  def runtimeStatistics: Statistics
+}
+
+/**
  * Performs a shuffle that will result in the desired partitioning.
  */
 case class ShuffleExchangeExec(
     override val outputPartitioning: Partitioning,
     child: SparkPlan,
-    canChangeNumPartitions: Boolean = true) extends Exchange {
+    canChangeNumPartitions: Boolean = true) extends ShuffleExchangeLike {
 
   private lazy val writeMetrics =
     SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
@@ -64,7 +101,7 @@ case class ShuffleExchangeExec(
   @transient lazy val inputRDD: RDD[InternalRow] = child.execute()
 
   // 'mapOutputStatisticsFuture' is only needed when AQE is enabled.
-  @transient lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = {
+  @transient override lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = {
     if (inputRDD.getNumPartitions == 0) {
       Future.successful(null)
     } else {
@@ -72,6 +109,20 @@ case class ShuffleExchangeExec(
     }
   }
 
+  override def numMappers: Int = shuffleDependency.rdd.getNumPartitions
+
+  override def numPartitions: Int = shuffleDependency.partitioner.numPartitions
+
+  override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[InternalRow] = {
+    new ShuffledRowRDD(shuffleDependency, readMetrics, partitionSpecs)
+  }
+
+  override def runtimeStatistics: Statistics = {
+    val dataSize = metrics("dataSize").value
+    val rowCount = metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN).value
+    Statistics(dataSize, Some(rowCount))
+  }
+
   /**
    * A [[ShuffleDependency]] that will partition rows of its child based on
    * the partitioning scheme defined in `newPartitioning`. Those partitions of
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 7773ac7..bfa60cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, HashPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{LeafExecNode, LocalLimitExec, QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode}
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.util.Utils
@@ -118,7 +118,7 @@ class IncrementalExecution(
           case s: StatefulOperator =>
             statefulOpFound = true
 
-          case e: ShuffleExchangeExec =>
+          case e: ShuffleExchangeLike =>
             // Don't search any further for child stateful operators, as we are only
             // looking for stateful subplans that this plan has narrow dependencies on.
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 44e784d..e5e8bc6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -16,19 +16,24 @@
  */
 package org.apache.spark.sql
 
-import java.util.Locale
+import java.util.{Locale, UUID}
 
-import org.apache.spark.{SparkFunSuite, TaskContext}
+import scala.concurrent.Future
+
+import org.apache.spark.{MapOutputStatistics, SparkFunSuite, TaskContext}
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Statistics, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike}
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.COLUMN_BATCH_SIZE
@@ -169,33 +174,61 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
     }
   }
 
-  test("inject columnar") {
+  test("inject columnar AQE on") {
+    testInjectColumnar(true)
+  }
+
+  test("inject columnar AQE off") {
+    testInjectColumnar(false)
+  }
+
+  private def testInjectColumnar(enableAQE: Boolean): Unit = {
+    def collectPlanSteps(plan: SparkPlan): Seq[Int] = plan match {
+      case a: AdaptiveSparkPlanExec =>
+        assert(a.toString.startsWith("AdaptiveSparkPlan isFinalPlan=true"))
+        collectPlanSteps(a.executedPlan)
+      case _ => plan.collect {
+        case _: ReplacedRowToColumnarExec => 1
+        case _: ColumnarProjectExec => 10
+        case _: ColumnarToRowExec => 100
+        case s: QueryStageExec => collectPlanSteps(s.plan).sum
+        case _: MyShuffleExchangeExec => 1000
+        case _: MyBroadcastExchangeExec => 10000
+      }
+    }
+
     val extensions = create { extensions =>
       extensions.injectColumnar(session =>
         MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule()))
     }
     withSession(extensions) { session =>
-      // The ApplyColumnarRulesAndInsertTransitions rule is not applied when enable AQE
-      session.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, false)
+      session.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, enableAQE)
       assert(session.sessionState.columnarRules.contains(
         MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
       import session.sqlContext.implicits._
-      // repartitioning avoids having the add operation pushed up into the LocalTableScan
-      val data = Seq((100L), (200L), (300L)).toDF("vals").repartition(1)
-      val df = data.selectExpr("vals + 1")
-      // Verify that both pre and post processing of the plan worked.
-      val found = df.queryExecution.executedPlan.collect {
-        case rep: ReplacedRowToColumnarExec => 1
-        case proj: ColumnarProjectExec => 10
-        case c2r: ColumnarToRowExec => 100
-      }.sum
-      assert(found == 111)
+      // perform a join to inject a broadcast exchange
+      val left = Seq((1, 50L), (2, 100L), (3, 150L)).toDF("l1", "l2")
+      val right = Seq((1, 50L), (2, 100L), (3, 150L)).toDF("r1", "r2")
+      val data = left.join(right, $"l1" === $"r1")
+        // repartitioning avoids having the add operation pushed up into the LocalTableScan
+        .repartition(1)
+      val df = data.selectExpr("l2 + r2")
+      // execute the plan so that the final adaptive plan is available when AQE is on
+      df.collect()
+      val found = collectPlanSteps(df.queryExecution.executedPlan).sum
+      // 1 MyBroadcastExchangeExec
+      // 1 MyShuffleExchangeExec
+      // 1 ColumnarToRowExec
+      // 2 ColumnarProjectExec
+      // 1 ReplacedRowToColumnarExec
+      // so 11121 is expected.
+      assert(found == 11121)
 
       // Verify that we get back the expected, wrong, result
       val result = df.collect()
-      assert(result(0).getLong(0) == 102L) // Check that broken columnar Add was used.
-      assert(result(1).getLong(0) == 202L)
-      assert(result(2).getLong(0) == 302L)
+      assert(result(0).getLong(0) == 101L) // Check that broken columnar Add was used.
+      assert(result(1).getLong(0) == 201L)
+      assert(result(2).getLong(0) == 301L)
     }
   }
 
@@ -695,6 +728,16 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] {
   def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan =
     try {
       plan match {
+        case e: ShuffleExchangeExec =>
+          // note that this is not actually columnar but demonstrates that exchanges can
+          // be replaced.
+          val replaced = e.withNewChildren(e.children.map(replaceWithColumnarPlan))
+          MyShuffleExchangeExec(replaced.asInstanceOf[ShuffleExchangeExec])
+        case e: BroadcastExchangeExec =>
+          // note that this is not actually columnar but demonstrates that exchanges can
+          // be replaced.
+          val replaced = e.withNewChildren(e.children.map(replaceWithColumnarPlan))
+          MyBroadcastExchangeExec(replaced.asInstanceOf[BroadcastExchangeExec])
         case plan: ProjectExec =>
           new ColumnarProjectExec(plan.projectList.map((exp) =>
             replaceWithColumnarExpression(exp).asInstanceOf[NamedExpression]),
@@ -713,6 +756,41 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] {
   override def apply(plan: SparkPlan): SparkPlan = replaceWithColumnarPlan(plan)
 }
 
+/**
+ * Custom Exchange used in tests to demonstrate that shuffles can be replaced regardless of
+ * whether AQE is enabled.
+ */
+case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleExchangeLike {
+  override def numMappers: Int = delegate.numMappers
+  override def numPartitions: Int = delegate.numPartitions
+  override def canChangeNumPartitions: Boolean = delegate.canChangeNumPartitions
+  override def mapOutputStatisticsFuture: Future[MapOutputStatistics] =
+    delegate.mapOutputStatisticsFuture
+  override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] =
+    delegate.getShuffleRDD(partitionSpecs)
+  override def runtimeStatistics: Statistics = delegate.runtimeStatistics
+  override def child: SparkPlan = delegate.child
+  override protected def doExecute(): RDD[InternalRow] = delegate.execute()
+  override def outputPartitioning: Partitioning = delegate.outputPartitioning
+}
+
+/**
+ * Custom Exchange used in tests to demonstrate that broadcasts can be replaced regardless of
+ * whether AQE is enabled.
+ */
+case class MyBroadcastExchangeExec(delegate: BroadcastExchangeExec) extends BroadcastExchangeLike {
+  override def runId: UUID = delegate.runId
+  override def relationFuture: java.util.concurrent.Future[Broadcast[Any]] =
+    delegate.relationFuture
+  override def completionFuture: Future[Broadcast[Any]] = delegate.completionFuture
+  override def runtimeStatistics: Statistics = delegate.runtimeStatistics
+  override def child: SparkPlan = delegate.child
+  override protected def doPrepare(): Unit = delegate.prepare()
+  override protected def doExecute(): RDD[InternalRow] = delegate.execute()
+  override def doExecuteBroadcast[T](): Broadcast[T] = delegate.executeBroadcast()
+  override def outputPartitioning: Partitioning = delegate.outputPartitioning
+}
+
 class ReplacedRowToColumnarExec(override val child: SparkPlan)
   extends RowToColumnarExec(child) {
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org