Posted to commits@spark.apache.org by tg...@apache.org on 2020/07/31 16:15:41 UTC
[spark] branch branch-3.0 updated: [SPARK-32332][SQL][3.0] Support columnar exchanges
This is an automated email from the ASF dual-hosted git repository.
tgraves pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
new 7c91b15 [SPARK-32332][SQL][3.0] Support columnar exchanges
7c91b15 is described below
commit 7c91b15c22fe875a44e08781caf3422bfca81b19
Author: Andy Grove <an...@nvidia.com>
AuthorDate: Fri Jul 31 11:14:33 2020 -0500
[SPARK-32332][SQL][3.0] Support columnar exchanges
### What changes were proposed in this pull request?
Backports SPARK-32332 to the 3.0 branch.
### Why are the changes needed?
Without this patch, plugins cannot replace shuffle and broadcast exchanges with columnar versions when adaptive query execution (AQE) is enabled.
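For context, a minimal sketch of how such a plugin hooks in. The entry-point class name is an assumption; the injected rule is the one exercised by the test suite in this patch:

    import org.apache.spark.sql.SparkSessionExtensions

    // Hypothetical plugin entry point (the class name is an assumption),
    // registered via the spark.sql.extensions config. The injected rule is
    // the one defined in SparkSessionExtensionSuite below. Before this patch,
    // such a rule could not replace exchanges once AQE re-planned the query.
    class MyExtensions extends (SparkSessionExtensions => Unit) {
      override def apply(extensions: SparkSessionExtensions): Unit = {
        extensions.injectColumnar(session =>
          MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule()))
      }
    }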
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Tests included.
Closes #29310 from andygrove/backport-SPARK-32332.
Authored-by: Andy Grove <an...@nvidia.com>
Signed-off-by: Thomas Graves <tg...@apache.org>
---
.../execution/adaptive/AdaptiveSparkPlanExec.scala | 30 ++++--
.../adaptive/CustomShuffleReaderExec.scala | 37 ++++---
.../adaptive/OptimizeLocalShuffleReader.scala | 5 +-
.../execution/adaptive/OptimizeSkewedJoin.scala | 17 +--
.../sql/execution/adaptive/QueryStageExec.scala | 24 +++--
.../sql/execution/adaptive/simpleCosting.scala | 6 +-
.../execution/exchange/BroadcastExchangeExec.scala | 42 +++++++-
.../execution/exchange/ShuffleExchangeExec.scala | 55 +++++++++-
.../execution/streaming/IncrementalExecution.scala | 4 +-
.../spark/sql/SparkSessionExtensionSuite.scala | 120 +++++++++++++++++----
10 files changed, 272 insertions(+), 68 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 5714c33..8b59b12 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -100,7 +100,12 @@ case class AdaptiveSparkPlanExec(
// The following two rules need to make use of 'CustomShuffleReaderExec.partitionSpecs'
// added by `CoalesceShufflePartitions`. So they must be executed after it.
OptimizeSkewedJoin(conf),
- OptimizeLocalShuffleReader(conf),
+ OptimizeLocalShuffleReader(conf)
+ )
+
+ // A list of physical optimizer rules to be applied right after a new stage is created. The input
+ // plan to these rules has exchange as its root node.
+ @transient private val postStageCreationRules = Seq(
ApplyColumnarRulesAndInsertTransitions(conf, context.session.sessionState.columnarRules),
CollapseCodegenStages(conf)
)
@@ -227,7 +232,8 @@ case class AdaptiveSparkPlanExec(
}
// Run the final plan when there's no more unfinished stages.
- currentPhysicalPlan = applyPhysicalRules(result.newPlan, queryStageOptimizerRules)
+ currentPhysicalPlan = applyPhysicalRules(
+ result.newPlan, queryStageOptimizerRules ++ postStageCreationRules)
isFinalPlan = true
executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
currentPhysicalPlan
@@ -375,10 +381,22 @@ case class AdaptiveSparkPlanExec(
private def newQueryStage(e: Exchange): QueryStageExec = {
val optimizedPlan = applyPhysicalRules(e.child, queryStageOptimizerRules)
val queryStage = e match {
- case s: ShuffleExchangeExec =>
- ShuffleQueryStageExec(currentStageId, s.copy(child = optimizedPlan))
- case b: BroadcastExchangeExec =>
- BroadcastQueryStageExec(currentStageId, b.copy(child = optimizedPlan))
+ case s: ShuffleExchangeLike =>
+ val newShuffle = applyPhysicalRules(
+ s.withNewChildren(Seq(optimizedPlan)), postStageCreationRules)
+ if (!newShuffle.isInstanceOf[ShuffleExchangeLike]) {
+ throw new IllegalStateException(
+ "Custom columnar rules cannot transform shuffle node to something else.")
+ }
+ ShuffleQueryStageExec(currentStageId, newShuffle)
+ case b: BroadcastExchangeLike =>
+ val newBroadcast = applyPhysicalRules(
+ b.withNewChildren(Seq(optimizedPlan)), postStageCreationRules)
+ if (!newBroadcast.isInstanceOf[BroadcastExchangeLike]) {
+ throw new IllegalStateException(
+ "Custom columnar rules cannot transform broadcast node to something else.")
+ }
+ BroadcastQueryStageExec(currentStageId, newBroadcast)
}
currentStageId += 1
setLogicalLinkForNewQueryStage(queryStage, e)
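The checks above pin down the contract for plugin rules: a post-stage-creation rule may rewrite an exchange, but whatever it returns must still be a ShuffleExchangeLike (or BroadcastExchangeLike). A minimal sketch of a conforming rule, reusing the delegating wrapper defined in the test suite later in this diff:

    import org.apache.spark.sql.catalyst.rules.Rule
    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

    // Wraps each shuffle in MyShuffleExchangeExec (still a ShuffleExchangeLike,
    // see the test suite below), so the isInstanceOf check in newQueryStage
    // passes. Returning an arbitrary node instead would throw the
    // IllegalStateException above.
    case class ReplaceShuffles() extends Rule[SparkPlan] {
      override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
        case s: ShuffleExchangeExec => MyShuffleExchangeExec(s)
      }
    }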
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala
index ba3f725..8fd5720 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala
@@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
+import org.apache.spark.sql.vectorized.ColumnarBatch
/**
@@ -38,6 +39,8 @@ case class CustomShuffleReaderExec private(
partitionSpecs: Seq[ShufflePartitionSpec],
description: String) extends UnaryExecNode {
+ override def supportsColumnar: Boolean = child.supportsColumnar
+
override def output: Seq[Attribute] = child.output
override lazy val outputPartitioning: Partitioning = {
// If it is a local shuffle reader with one mapper per task, then the output partitioning is
@@ -47,9 +50,9 @@ case class CustomShuffleReaderExec private(
partitionSpecs.map(_.asInstanceOf[PartialMapperPartitionSpec].mapIndex).toSet.size ==
partitionSpecs.length) {
child match {
- case ShuffleQueryStageExec(_, s: ShuffleExchangeExec) =>
+ case ShuffleQueryStageExec(_, s: ShuffleExchangeLike) =>
s.child.outputPartitioning
- case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeExec)) =>
+ case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeLike)) =>
s.child.outputPartitioning match {
case e: Expression => r.updateAttr(e).asInstanceOf[Partitioning]
case other => other
@@ -64,18 +67,24 @@ case class CustomShuffleReaderExec private(
override def stringArgs: Iterator[Any] = Iterator(description)
- private var cachedShuffleRDD: RDD[InternalRow] = null
+ private def shuffleStage = child match {
+ case stage: ShuffleQueryStageExec => Some(stage)
+ case _ => None
+ }
- override protected def doExecute(): RDD[InternalRow] = {
- if (cachedShuffleRDD == null) {
- cachedShuffleRDD = child match {
- case stage: ShuffleQueryStageExec =>
- new ShuffledRowRDD(
- stage.shuffle.shuffleDependency, stage.shuffle.readMetrics, partitionSpecs.toArray)
- case _ =>
- throw new IllegalStateException("operating on canonicalization plan")
- }
+ private lazy val shuffleRDD: RDD[_] = {
+ shuffleStage.map { stage =>
+ stage.shuffle.getShuffleRDD(partitionSpecs.toArray)
+ }.getOrElse {
+ throw new IllegalStateException("operating on canonicalized plan")
}
- cachedShuffleRDD
+ }
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ shuffleRDD.asInstanceOf[RDD[InternalRow]]
+ }
+
+ override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
+ shuffleRDD.asInstanceOf[RDD[ColumnarBatch]]
}
}
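The reworked reader illustrates a general pattern for nodes that can be either row-based or columnar: one lazily materialized RDD backs both execution paths, and the casts are safe because Spark consults supportsColumnar before deciding which path to call. A standalone sketch of the same pattern (PassThroughExec is a hypothetical node, not part of this commit):

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.Attribute
    import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
    import org.apache.spark.sql.vectorized.ColumnarBatch

    // Hypothetical node: builds its result RDD once and serves it through
    // whichever execution path Spark selects based on supportsColumnar.
    case class PassThroughExec(child: SparkPlan) extends UnaryExecNode {
      override def output: Seq[Attribute] = child.output
      override def supportsColumnar: Boolean = child.supportsColumnar

      // Built once, shared by both execute paths below.
      private lazy val resultRDD: RDD[_] =
        if (child.supportsColumnar) child.executeColumnar() else child.execute()

      override protected def doExecute(): RDD[InternalRow] =
        resultRDD.asInstanceOf[RDD[InternalRow]]

      override protected def doExecuteColumnar(): RDD[ColumnarBatch] =
        resultRDD.asInstanceOf[RDD[ColumnarBatch]]
    }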
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
index fb6b40c..6684376 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
@@ -78,10 +78,9 @@ case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {
private def getPartitionSpecs(
shuffleStage: ShuffleQueryStageExec,
advisoryParallelism: Option[Int]): Seq[ShufflePartitionSpec] = {
- val shuffleDep = shuffleStage.shuffle.shuffleDependency
- val numReducers = shuffleDep.partitioner.numPartitions
+ val numMappers = shuffleStage.shuffle.numMappers
+ val numReducers = shuffleStage.shuffle.numPartitions
val expectedParallelism = advisoryParallelism.getOrElse(numReducers)
- val numMappers = shuffleDep.rdd.getNumPartitions
val splitPoints = if (numMappers == 0) {
Seq.empty
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
index 91ae0b9..b3b3eb2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
import org.apache.commons.io.FileUtils
-import org.apache.spark.{MapOutputTrackerMaster, SparkEnv}
+import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
@@ -197,7 +197,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
val leftParts = if (isLeftSkew && !isLeftCoalesced) {
val reducerId = leftPartSpec.startReducerIndex
val skewSpecs = createSkewPartitionSpecs(
- left.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, leftTargetSize)
+ left.mapStats.shuffleId, reducerId, leftTargetSize)
if (skewSpecs.isDefined) {
logDebug(s"Left side partition $partitionIndex is skewed, split it into " +
s"${skewSpecs.get.length} parts.")
@@ -212,7 +212,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
val rightParts = if (isRightSkew && !isRightCoalesced) {
val reducerId = rightPartSpec.startReducerIndex
val skewSpecs = createSkewPartitionSpecs(
- right.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, rightTargetSize)
+ right.mapStats.shuffleId, reducerId, rightTargetSize)
if (skewSpecs.isDefined) {
logDebug(s"Right side partition $partitionIndex is skewed, split it into " +
s"${skewSpecs.get.length} parts.")
@@ -287,15 +287,17 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
private object ShuffleStage {
def unapply(plan: SparkPlan): Option[ShuffleStageInfo] = plan match {
case s: ShuffleQueryStageExec if s.mapStats.isDefined =>
- val sizes = s.mapStats.get.bytesByPartitionId
+ val mapStats = s.mapStats.get
+ val sizes = mapStats.bytesByPartitionId
val partitions = sizes.zipWithIndex.map {
case (size, i) => CoalescedPartitionSpec(i, i + 1) -> size
}
- Some(ShuffleStageInfo(s, partitions))
+ Some(ShuffleStageInfo(s, mapStats, partitions))
case CustomShuffleReaderExec(s: ShuffleQueryStageExec, partitionSpecs, _)
if s.mapStats.isDefined && partitionSpecs.nonEmpty =>
- val sizes = s.mapStats.get.bytesByPartitionId
+ val mapStats = s.mapStats.get
+ val sizes = mapStats.bytesByPartitionId
val partitions = partitionSpecs.map {
case spec @ CoalescedPartitionSpec(start, end) =>
var sum = 0L
@@ -308,7 +310,7 @@ private object ShuffleStage {
case other => throw new IllegalArgumentException(
s"Expect CoalescedPartitionSpec but got $other")
}
- Some(ShuffleStageInfo(s, partitions))
+ Some(ShuffleStageInfo(s, mapStats, partitions))
case _ => None
}
@@ -316,6 +318,7 @@ private object ShuffleStage {
private case class ShuffleStageInfo(
shuffleStage: ShuffleQueryStageExec,
+ mapStats: MapOutputStatistics,
partitionsWithSizes: Seq[(CoalescedPartitionSpec, Long)])
private class SkewDesc {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
index 9a9a8b1..74fe1ea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange._
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.ThreadUtils
/**
@@ -81,6 +82,11 @@ abstract class QueryStageExec extends LeafExecNode {
def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec
/**
+ * Returns the runtime statistics after stage materialization.
+ */
+ def getRuntimeStatistics: Statistics
+
+ /**
* Compute the statistics of the query stage if executed, otherwise None.
*/
def computeStats(): Option[Statistics] = resultOption.map { _ =>
@@ -107,6 +113,8 @@ abstract class QueryStageExec extends LeafExecNode {
protected override def doPrepare(): Unit = plan.prepare()
protected override def doExecute(): RDD[InternalRow] = plan.execute()
+ override def supportsColumnar: Boolean = plan.supportsColumnar
+ protected override def doExecuteColumnar(): RDD[ColumnarBatch] = plan.executeColumnar()
override def doExecuteBroadcast[T](): Broadcast[T] = plan.executeBroadcast()
override def doCanonicalize(): SparkPlan = plan.canonicalized
@@ -135,15 +143,15 @@ abstract class QueryStageExec extends LeafExecNode {
}
/**
- * A shuffle query stage whose child is a [[ShuffleExchangeExec]] or [[ReusedExchangeExec]].
+ * A shuffle query stage whose child is a [[ShuffleExchangeLike]] or [[ReusedExchangeExec]].
*/
case class ShuffleQueryStageExec(
override val id: Int,
override val plan: SparkPlan) extends QueryStageExec {
@transient val shuffle = plan match {
- case s: ShuffleExchangeExec => s
- case ReusedExchangeExec(_, s: ShuffleExchangeExec) => s
+ case s: ShuffleExchangeLike => s
+ case ReusedExchangeExec(_, s: ShuffleExchangeLike) => s
case _ =>
throw new IllegalStateException("wrong plan for shuffle stage:\n " + plan.treeString)
}
@@ -176,18 +184,20 @@ case class ShuffleQueryStageExec(
val stats = resultOption.get.asInstanceOf[MapOutputStatistics]
Option(stats)
}
+
+ override def getRuntimeStatistics: Statistics = shuffle.runtimeStatistics
}
/**
- * A broadcast query stage whose child is a [[BroadcastExchangeExec]] or [[ReusedExchangeExec]].
+ * A broadcast query stage whose child is a [[BroadcastExchangeLike]] or [[ReusedExchangeExec]].
*/
case class BroadcastQueryStageExec(
override val id: Int,
override val plan: SparkPlan) extends QueryStageExec {
@transient val broadcast = plan match {
- case b: BroadcastExchangeExec => b
- case ReusedExchangeExec(_, b: BroadcastExchangeExec) => b
+ case b: BroadcastExchangeLike => b
+ case ReusedExchangeExec(_, b: BroadcastExchangeLike) => b
case _ =>
throw new IllegalStateException("wrong plan for broadcast stage:\n " + plan.treeString)
}
@@ -224,6 +234,8 @@ case class BroadcastQueryStageExec(
broadcast.relationFuture.cancel(true)
}
}
+
+ override def getRuntimeStatistics: Statistics = broadcast.runtimeStatistics
}
object BroadcastQueryStageExec {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala
index 67cd720..cdc57db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.adaptive
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
/**
* A simple implementation of [[Cost]], which takes a number of [[Long]] as the cost value.
@@ -35,13 +35,13 @@ case class SimpleCost(value: Long) extends Cost {
/**
* A simple implementation of [[CostEvaluator]], which counts the number of
- * [[ShuffleExchangeExec]] nodes in the plan.
+ * [[ShuffleExchangeLike]] nodes in the plan.
*/
object SimpleCostEvaluator extends CostEvaluator {
override def evaluateCost(plan: SparkPlan): Cost = {
val cost = plan.collect {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
}.size
SimpleCost(cost)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index d35bbe9..bcdaf61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -29,6 +29,7 @@ import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning}
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.joins.HashedRelation
@@ -38,15 +39,42 @@ import org.apache.spark.unsafe.map.BytesToBytesMap
import org.apache.spark.util.{SparkFatalException, ThreadUtils}
/**
+ * Common trait for all broadcast exchange implementations to facilitate pattern matching.
+ */
+trait BroadcastExchangeLike extends Exchange {
+
+ /**
+ * The broadcast job group ID
+ */
+ def runId: UUID = UUID.randomUUID
+
+ /**
+ * The asynchronous job that prepares the broadcast relation.
+ */
+ def relationFuture: Future[broadcast.Broadcast[Any]]
+
+ /**
+ * For registering callbacks on `relationFuture`.
+ * Note that calling this method may not start the execution of broadcast job.
+ */
+ def completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]]
+
+ /**
+ * Returns the runtime statistics after broadcast materialization.
+ */
+ def runtimeStatistics: Statistics
+}
+
+/**
* A [[BroadcastExchangeExec]] collects, transforms and finally broadcasts the result of
* a transformed SparkPlan.
*/
case class BroadcastExchangeExec(
mode: BroadcastMode,
- child: SparkPlan) extends Exchange {
+ child: SparkPlan) extends BroadcastExchangeLike {
import BroadcastExchangeExec._
- private[sql] val runId: UUID = UUID.randomUUID
+ override val runId: UUID = UUID.randomUUID
override lazy val metrics = Map(
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
@@ -60,6 +88,11 @@ case class BroadcastExchangeExec(
BroadcastExchangeExec(mode.canonicalized, child.canonicalized)
}
+ override def runtimeStatistics: Statistics = {
+ val dataSize = metrics("dataSize").value
+ Statistics(dataSize)
+ }
+
@transient
private lazy val promise = Promise[broadcast.Broadcast[Any]]()
@@ -68,13 +101,14 @@ case class BroadcastExchangeExec(
* Note that calling this field will not start the execution of broadcast job.
*/
@transient
- lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] = promise.future
+ override lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] =
+ promise.future
@transient
private val timeout: Long = SQLConf.get.broadcastTimeout
@transient
- private[sql] lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
+ override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
sqlContext.sparkSession, BroadcastExchangeExec.executionContext) {
try {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index b06742e..b7da78c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Divide, Literal, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
@@ -41,12 +42,48 @@ import org.apache.spark.util.MutablePair
import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}
/**
+ * Common trait for all shuffle exchange implementations to facilitate pattern matching.
+ */
+trait ShuffleExchangeLike extends Exchange {
+
+ /**
+ * Returns the number of mappers of this shuffle.
+ */
+ def numMappers: Int
+
+ /**
+ * Returns the shuffle partition number.
+ */
+ def numPartitions: Int
+
+ /**
+ * Returns whether the shuffle partition number can be changed.
+ */
+ def canChangeNumPartitions: Boolean
+
+ /**
+ * The asynchronous job that materializes the shuffle.
+ */
+ def mapOutputStatisticsFuture: Future[MapOutputStatistics]
+
+ /**
+ * Returns the shuffle RDD with specified partition specs.
+ */
+ def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_]
+
+ /**
+ * Returns the runtime statistics after shuffle materialization.
+ */
+ def runtimeStatistics: Statistics
+}
+
+/**
* Performs a shuffle that will result in the desired partitioning.
*/
case class ShuffleExchangeExec(
override val outputPartitioning: Partitioning,
child: SparkPlan,
- canChangeNumPartitions: Boolean = true) extends Exchange {
+ canChangeNumPartitions: Boolean = true) extends ShuffleExchangeLike {
private lazy val writeMetrics =
SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
@@ -64,7 +101,7 @@ case class ShuffleExchangeExec(
@transient lazy val inputRDD: RDD[InternalRow] = child.execute()
// 'mapOutputStatisticsFuture' is only needed when enable AQE.
- @transient lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = {
+ @transient override lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = {
if (inputRDD.getNumPartitions == 0) {
Future.successful(null)
} else {
@@ -72,6 +109,20 @@ case class ShuffleExchangeExec(
}
}
+ override def numMappers: Int = shuffleDependency.rdd.getNumPartitions
+
+ override def numPartitions: Int = shuffleDependency.partitioner.numPartitions
+
+ override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[InternalRow] = {
+ new ShuffledRowRDD(shuffleDependency, readMetrics, partitionSpecs)
+ }
+
+ override def runtimeStatistics: Statistics = {
+ val dataSize = metrics("dataSize").value
+ val rowCount = metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN).value
+ Statistics(dataSize, Some(rowCount))
+ }
+
/**
* A [[ShuffleDependency]] that will partition rows of its child based on
* the partitioning scheme defined in `newPartitioning`. Those partitions of
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 7773ac7..bfa60cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, HashPartitioning, SinglePartition}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{LeafExecNode, LocalLimitExec, QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode}
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.util.Utils
@@ -118,7 +118,7 @@ class IncrementalExecution(
case s: StatefulOperator =>
statefulOpFound = true
- case e: ShuffleExchangeExec =>
+ case e: ShuffleExchangeLike =>
// Don't search recursively any further as any child stateful operator as we
// are only looking for stateful subplans that this plan has narrow dependencies on.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 44e784d..e5e8bc6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -16,19 +16,24 @@
*/
package org.apache.spark.sql
-import java.util.Locale
+import java.util.{Locale, UUID}
-import org.apache.spark.{SparkFunSuite, TaskContext}
+import scala.concurrent.Future
+
+import org.apache.spark.{MapOutputStatistics, SparkFunSuite, TaskContext}
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Statistics, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.COLUMN_BATCH_SIZE
@@ -169,33 +174,61 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
}
}
- test("inject columnar") {
+ test("inject columnar AQE on") {
+ testInjectColumnar(true)
+ }
+
+ test("inject columnar AQE off") {
+ testInjectColumnar(false)
+ }
+
+ private def testInjectColumnar(enableAQE: Boolean): Unit = {
+ def collectPlanSteps(plan: SparkPlan): Seq[Int] = plan match {
+ case a: AdaptiveSparkPlanExec =>
+ assert(a.toString.startsWith("AdaptiveSparkPlan isFinalPlan=true"))
+ collectPlanSteps(a.executedPlan)
+ case _ => plan.collect {
+ case _: ReplacedRowToColumnarExec => 1
+ case _: ColumnarProjectExec => 10
+ case _: ColumnarToRowExec => 100
+ case s: QueryStageExec => collectPlanSteps(s.plan).sum
+ case _: MyShuffleExchangeExec => 1000
+ case _: MyBroadcastExchangeExec => 10000
+ }
+ }
+
val extensions = create { extensions =>
extensions.injectColumnar(session =>
MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule()))
}
withSession(extensions) { session =>
- // The ApplyColumnarRulesAndInsertTransitions rule is not applied when enable AQE
- session.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, false)
+ session.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, enableAQE)
assert(session.sessionState.columnarRules.contains(
MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
import session.sqlContext.implicits._
- // repartitioning avoids having the add operation pushed up into the LocalTableScan
- val data = Seq((100L), (200L), (300L)).toDF("vals").repartition(1)
- val df = data.selectExpr("vals + 1")
- // Verify that both pre and post processing of the plan worked.
- val found = df.queryExecution.executedPlan.collect {
- case rep: ReplacedRowToColumnarExec => 1
- case proj: ColumnarProjectExec => 10
- case c2r: ColumnarToRowExec => 100
- }.sum
- assert(found == 111)
+ // perform a join to inject a broadcast exchange
+ val left = Seq((1, 50L), (2, 100L), (3, 150L)).toDF("l1", "l2")
+ val right = Seq((1, 50L), (2, 100L), (3, 150L)).toDF("r1", "r2")
+ val data = left.join(right, $"l1" === $"r1")
+ // repartitioning avoids having the add operation pushed up into the LocalTableScan
+ .repartition(1)
+ val df = data.selectExpr("l2 + r2")
+ // execute the plan so that the final adaptive plan is available when AQE is on
+ df.collect()
+ val found = collectPlanSteps(df.queryExecution.executedPlan).sum
+ // 1 MyBroadcastExchangeExec
+ // 1 MyShuffleExchangeExec
+ // 1 ColumnarToRowExec
+ // 2 ColumnarProjectExec
+ // 1 ReplacedRowToColumnarExec
+ // so 11121 is expected.
+ assert(found == 11121)
// Verify that we get back the expected, wrong, result
val result = df.collect()
- assert(result(0).getLong(0) == 102L) // Check that broken columnar Add was used.
- assert(result(1).getLong(0) == 202L)
- assert(result(2).getLong(0) == 302L)
+ assert(result(0).getLong(0) == 101L) // Check that broken columnar Add was used.
+ assert(result(1).getLong(0) == 201L)
+ assert(result(2).getLong(0) == 301L)
}
}
@@ -695,6 +728,16 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] {
def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan =
try {
plan match {
+ case e: ShuffleExchangeExec =>
+ // note that this is not actually columnar but demonstrates that exchanges can
+ // be replaced.
+ val replaced = e.withNewChildren(e.children.map(replaceWithColumnarPlan))
+ MyShuffleExchangeExec(replaced.asInstanceOf[ShuffleExchangeExec])
+ case e: BroadcastExchangeExec =>
+ // note that this is not actually columnar but demonstrates that exchanges can
+ // be replaced.
+ val replaced = e.withNewChildren(e.children.map(replaceWithColumnarPlan))
+ MyBroadcastExchangeExec(replaced.asInstanceOf[BroadcastExchangeExec])
case plan: ProjectExec =>
new ColumnarProjectExec(plan.projectList.map((exp) =>
replaceWithColumnarExpression(exp).asInstanceOf[NamedExpression]),
@@ -713,6 +756,41 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] {
override def apply(plan: SparkPlan): SparkPlan = replaceWithColumnarPlan(plan)
}
+/**
+ * Custom Exchange used in tests to demonstrate that shuffles can be replaced regardless of
+ * whether AQE is enabled.
+ */
+case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleExchangeLike {
+ override def numMappers: Int = delegate.numMappers
+ override def numPartitions: Int = delegate.numPartitions
+ override def canChangeNumPartitions: Boolean = delegate.canChangeNumPartitions
+ override def mapOutputStatisticsFuture: Future[MapOutputStatistics] =
+ delegate.mapOutputStatisticsFuture
+ override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] =
+ delegate.getShuffleRDD(partitionSpecs)
+ override def runtimeStatistics: Statistics = delegate.runtimeStatistics
+ override def child: SparkPlan = delegate.child
+ override protected def doExecute(): RDD[InternalRow] = delegate.execute()
+ override def outputPartitioning: Partitioning = delegate.outputPartitioning
+}
+
+/**
+ * Custom Exchange used in tests to demonstrate that broadcasts can be replaced regardless of
+ * whether AQE is enabled.
+ */
+case class MyBroadcastExchangeExec(delegate: BroadcastExchangeExec) extends BroadcastExchangeLike {
+ override def runId: UUID = delegate.runId
+ override def relationFuture: java.util.concurrent.Future[Broadcast[Any]] =
+ delegate.relationFuture
+ override def completionFuture: Future[Broadcast[Any]] = delegate.completionFuture
+ override def runtimeStatistics: Statistics = delegate.runtimeStatistics
+ override def child: SparkPlan = delegate.child
+ override protected def doPrepare(): Unit = delegate.prepare()
+ override protected def doExecute(): RDD[InternalRow] = delegate.execute()
+ override def doExecuteBroadcast[T](): Broadcast[T] = delegate.executeBroadcast()
+ override def outputPartitioning: Partitioning = delegate.outputPartitioning
+}
+
class ReplacedRowToColumnarExec(override val child: SparkPlan)
extends RowToColumnarExec(child) {