Posted to commits@spark.apache.org by we...@apache.org on 2021/08/03 10:30:20 UTC
[spark] branch branch-3.2 updated: [SPARK-36315][SQL] Only skip AQEShuffleReadRule in the final stage if it breaks the distribution requirement
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push:
new 8d817dc [SPARK-36315][SQL] Only skip AQEShuffleReadRule in the final stage if it breaks the distribution requirement
8d817dc is described below
commit 8d817dcf3084d56da22b909d578a644143f775d5
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Tue Aug 3 18:28:52 2021 +0800
[SPARK-36315][SQL] Only skip AQEShuffleReadRule in the final stage if it breaks the distribution requirement
### What changes were proposed in this pull request?
This is a follow-up to https://github.com/apache/spark/pull/30494
This PR proposes a new way to optimize the final query stage in AQE. We first collect the effective user-specified repartition (semantically, a user-specified repartition is only effective if it's the root node or sits under a few simple nodes), and derive the required distribution for the final plan from it. When we optimize the final query stage, we skip an `AQEShuffleReadRule` if applying it would break the required distribution.
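For illustration, here is a minimal sketch (hypothetical DataFrame and column names, not part of the patch; assumes a `SparkSession` named `spark`) of the two situations the new logic distinguishes:
```scala
import spark.implicits._

val df = spark.range(100).selectExpr("id % 10 AS key", "id AS b")

// Effective repartition: it is the root node, so the final plan must keep a
// hash distribution on `b`. AQE may still coalesce partitions, since that
// preserves the distribution.
df.repartition($"b").collect()

// The partition column is projected away, so the requirement cannot be
// expressed as a distribution of the final output. In this case the
// user-specified repartition shuffle is retained instead of optimized out.
df.repartition($"b").select($"key").collect()
```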
### Why are the changes needed?
The current solution for optimizing the final query stage is pretty hacky and overkill. As an example, the newly added rule `OptimizeSkewInRebalancePartitions` can hardly ever apply, as it's very common for a query plan to contain shuffles with origin `ENSURE_REQUIREMENTS`, which `OptimizeSkewInRebalancePartitions` does not support.
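To make the old behavior concrete, here is a simplified sketch of the all-or-nothing gating this PR removes (condensed from the deleted `finalStageOptimizerRules` shown in the diff below):
```scala
import org.apache.spark.sql.execution.exchange._

// Previously, a rule was dropped for the final stage unless it supported the
// origin of EVERY shuffle in the plan, so one ENSURE_REQUIREMENTS shuffle
// disqualified OptimizeSkewInRebalancePartitions entirely.
val origins = Seq(ENSURE_REQUIREMENTS, REBALANCE_PARTITIONS_BY_COL)
val supported = Seq(REBALANCE_PARTITIONS_BY_NONE, REBALANCE_PARTITIONS_BY_COL)
val ruleApplies = origins.forall(supported.contains) // false: rule skipped
```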
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
updated tests
Closes #33541 from cloud-fan/aqe.
Authored-by: Wenchen Fan <we...@databricks.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
(cherry picked from commit dd80457ffb1c129a1ca3c53bcf3ea5feed7ebc57)
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../spark/sql/execution/QueryExecution.scala | 2 +-
.../execution/adaptive/AQEShuffleReadExec.scala | 36 ++++-
.../execution/adaptive/AQEShuffleReadRule.scala | 12 +-
.../spark/sql/execution/adaptive/AQEUtils.scala | 60 ++++++++
.../execution/adaptive/AdaptiveSparkPlanExec.scala | 63 ++++----
.../adaptive/CoalesceShufflePartitions.scala | 10 +-
.../adaptive/OptimizeShuffleWithLocalRead.scala | 12 +-
.../OptimizeSkewInRebalancePartitions.scala | 8 +-
.../execution/adaptive/OptimizeSkewedJoin.scala | 15 +-
.../execution/exchange/EnsureRequirements.scala | 16 +-
.../execution/exchange/ValidateRequirements.scala | 4 +
.../apache/spark/sql/execution/PlannerSuite.scala | 2 +
.../adaptive/AdaptiveQueryExecSuite.scala | 170 +++++++++++----------
.../exchange/EnsureRequirementsSuite.scala | 27 +---
.../sql/execution/joins/BroadcastJoinSuite.scala | 2 +
.../sql/execution/joins/ExistenceJoinSuite.scala | 2 +
.../spark/sql/execution/joins/InnerJoinSuite.scala | 2 +
.../spark/sql/execution/joins/OuterJoinSuite.scala | 2 +
18 files changed, 274 insertions(+), 171 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 5a654ad..6c16dce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -409,7 +409,7 @@ object QueryExecution {
PlanDynamicPruningFilters(sparkSession),
PlanSubqueries(sparkSession),
RemoveRedundantProjects,
- EnsureRequirements,
+ EnsureRequirements(),
// `RemoveRedundantSorts` needs to be added after `EnsureRequirements` to guarantee the same
// number of partitions when instantiating PartitioningCollection.
RemoveRedundantSorts,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
index 0768b9b..af62157 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
@@ -22,13 +22,13 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, UnknownPartitioning}
+import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.vectorized.ColumnarBatch
-
/**
* A wrapper of shuffle query stage, which follows the given partition arrangement.
*
@@ -51,6 +51,7 @@ case class AQEShuffleReadExec private(
override def supportsColumnar: Boolean = child.supportsColumnar
override def output: Seq[Attribute] = child.output
+
override lazy val outputPartitioning: Partitioning = {
// If it is a local shuffle read with one mapper per task, then the output partitioning is
// the same as the plan before shuffle.
@@ -69,6 +70,21 @@ case class AQEShuffleReadExec private(
case _ =>
throw new IllegalStateException("operating on canonicalization plan")
}
+ } else if (isCoalescedRead) {
+ // For coalesced shuffle read, the data distribution is not changed, only the number of
+ // partitions is changed.
+ child.outputPartitioning match {
+ case h: HashPartitioning =>
+ CurrentOrigin.withOrigin(h.origin)(h.copy(numPartitions = partitionSpecs.length))
+ case r: RangePartitioning =>
+ CurrentOrigin.withOrigin(r.origin)(r.copy(numPartitions = partitionSpecs.length))
+ // This can only happen for `REBALANCE_PARTITIONS_BY_NONE`, which uses
+ // `RoundRobinPartitioning` but we don't need to retain the number of partitions.
+ case r: RoundRobinPartitioning =>
+ r.copy(numPartitions = partitionSpecs.length)
+ case other => throw new IllegalStateException(
+ "Unexpected partitioning for coalesced shuffle read: " + other)
+ }
} else {
UnknownPartitioning(partitionSpecs.length)
}
@@ -92,7 +108,7 @@ case class AQEShuffleReadExec private(
/**
* Returns true iff some partitions were actually combined
*/
- private def isCoalesced(spec: ShufflePartitionSpec) = spec match {
+ private def isCoalescedSpec(spec: ShufflePartitionSpec) = spec match {
case CoalescedPartitionSpec(0, 0, _) => true
case s: CoalescedPartitionSpec => s.endReducerIndex - s.startReducerIndex > 1
case _ => false
@@ -102,7 +118,7 @@ case class AQEShuffleReadExec private(
* Returns true iff some non-empty partitions were combined
*/
def hasCoalescedPartition: Boolean = {
- partitionSpecs.exists(isCoalesced)
+ partitionSpecs.exists(isCoalescedSpec)
}
def hasSkewedPartition: Boolean =
@@ -112,6 +128,16 @@ case class AQEShuffleReadExec private(
partitionSpecs.exists(_.isInstanceOf[PartialMapperPartitionSpec]) ||
partitionSpecs.exists(_.isInstanceOf[CoalescedMapperPartitionSpec])
+ def isCoalescedRead: Boolean = {
+ partitionSpecs.sliding(2).forall {
+ // A single partition spec which is `CoalescedPartitionSpec` also means coalesced read.
+ case Seq(_: CoalescedPartitionSpec) => true
+ case Seq(l: CoalescedPartitionSpec, r: CoalescedPartitionSpec) =>
+ l.endReducerIndex <= r.startReducerIndex
+ case _ => false
+ }
+ }
+
private def shuffleStage = child match {
case stage: ShuffleQueryStageExec => Some(stage)
case _ => None
@@ -159,7 +185,7 @@ case class AQEShuffleReadExec private(
if (hasCoalescedPartition) {
val numCoalescedPartitionsMetric = metrics("numCoalescedPartitions")
- val x = partitionSpecs.count(isCoalesced)
+ val x = partitionSpecs.count(isCoalescedSpec)
numCoalescedPartitionsMetric.set(x)
driverAccumUpdates += numCoalescedPartitionsMetric.id -> x
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadRule.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadRule.scala
index 1c7f2ea..c303e85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadRule.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadRule.scala
@@ -19,17 +19,19 @@ package org.apache.spark.sql.execution.adaptive
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.exchange.ShuffleOrigin
+import org.apache.spark.sql.execution.exchange.{ShuffleExchangeLike, ShuffleOrigin}
/**
- * Adaptive Query Execution rule that may create [[AQEShuffleReadExec]] on top of query stages.
+ * A rule that may create [[AQEShuffleReadExec]] on top of [[ShuffleQueryStageExec]] and change the
+ * plan output partitioning. The AQE framework will skip the rule if it leads to extra shuffles.
*/
trait AQEShuffleReadRule extends Rule[SparkPlan] {
-
/**
* Returns the list of [[ShuffleOrigin]]s supported by this rule.
*/
- def supportedShuffleOrigins: Seq[ShuffleOrigin]
+ protected def supportedShuffleOrigins: Seq[ShuffleOrigin]
- def mayAddExtraShuffles: Boolean = false
+ protected def isSupported(shuffle: ShuffleExchangeLike): Boolean = {
+ supportedShuffleOrigins.contains(shuffle.shuffleOrigin)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEUtils.scala
new file mode 100644
index 0000000..277af21
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEUtils.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution, HashPartitioning, UnspecifiedDistribution}
+import org.apache.spark.sql.execution.{CollectMetricsExec, FilterExec, ProjectExec, SortExec, SparkPlan}
+import org.apache.spark.sql.execution.exchange.{REPARTITION_BY_COL, REPARTITION_BY_NUM, ShuffleExchangeExec}
+
+object AQEUtils {
+
+ // Analyze the given plan and calculate the required distribution of this plan w.r.t. the
+ // user-specified repartition.
+ def getRequiredDistribution(p: SparkPlan): Option[Distribution] = p match {
+ // User-specified repartition is only effective when it's the root node, or under
+ // Project/Filter/LocalSort/CollectMetrics.
+ // Note: we only care about `HashPartitioning` as `EnsureRequirements` can only optimize out
+ // user-specified repartition with `HashPartitioning`.
+ case ShuffleExchangeExec(h: HashPartitioning, _, shuffleOrigin)
+ if shuffleOrigin == REPARTITION_BY_COL || shuffleOrigin == REPARTITION_BY_NUM =>
+ val numPartitions = if (shuffleOrigin == REPARTITION_BY_NUM) {
+ Some(h.numPartitions)
+ } else {
+ None
+ }
+ Some(HashClusteredDistribution(h.expressions, numPartitions))
+ case f: FilterExec => getRequiredDistribution(f.child)
+ case s: SortExec if !s.global => getRequiredDistribution(s.child)
+ case c: CollectMetricsExec => getRequiredDistribution(c.child)
+ case p: ProjectExec =>
+ getRequiredDistribution(p.child).flatMap {
+ case h: HashClusteredDistribution =>
+ if (h.expressions.forall(e => p.projectList.exists(_.semanticEquals(e)))) {
+ Some(h)
+ } else {
+ // It's possible that the user-specified repartition is effective but the output
+ // partitioning is not retained, e.g. `df.repartition(a, b).select(c)`. We can't
+ // handle this case with required distribution. Here we return None and later on
+ // `EnsureRequirements` will skip optimizing out the user-specified repartition.
+ None
+ }
+ case other => Some(other)
+ }
+ case _ => Some(UnspecifiedDistribution)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index c03bb4b..9db4574 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
+import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.errors.QueryExecutionErrors
@@ -83,12 +84,28 @@ case class AdaptiveSparkPlanExec(
// The logical plan optimizer for re-optimizing the current logical plan.
@transient private val optimizer = new AQEOptimizer(conf)
+ // `EnsureRequirements` may remove user-specified repartition and assume the query plan won't
+ // change its output partitioning. This assumption is not true in AQE. Here we check the
+ // `inputPlan` which has not been processed by `EnsureRequirements` yet, to find out the
+ // effective user-specified repartition. Later on, the AQE framework will make sure the final
+ // output partitioning is not changed w.r.t the effective user-specified repartition.
+ @transient private val requiredDistribution: Option[Distribution] = if (isSubquery) {
+ // Subquery output does not need a specific output partitioning.
+ Some(UnspecifiedDistribution)
+ } else {
+ AQEUtils.getRequiredDistribution(inputPlan)
+ }
+
// A list of physical plan rules to be applied before creation of query stages. The physical
// plan should reach a final status of query stages (i.e., no more addition or removal of
// Exchange nodes) after running these rules.
private def queryStagePreparationRules: Seq[Rule[SparkPlan]] = Seq(
RemoveRedundantProjects,
- EnsureRequirements,
+ // For cases like `df.repartition(a, b).select(c)`, there is no distribution requirement for
+ // the final plan, but we do need to respect the user-specified repartition. Here we ask
+ // `EnsureRequirements` to not optimize out the user-specified repartition-by-col to work
+ // around this case.
+ EnsureRequirements(optimizeOutRepartition = requiredDistribution.isDefined),
RemoveRedundantSorts,
DisableUnnecessaryBucketedScan
) ++ context.session.sessionState.queryStagePrepRules
@@ -114,33 +131,24 @@ case class AdaptiveSparkPlanExec(
CollapseCodegenStages()
) ++ context.session.sessionState.postStageCreationRules
- // The partitioning of the query output depends on the shuffle(s) in the final stage. If the
- // original plan contains a repartition operator, we need to preserve the specified partitioning,
- // whether or not the repartition-introduced shuffle is optimized out because of an underlying
- // shuffle of the same partitioning. Thus, we need to exclude some `AQEShuffleReadRule`s
- // from the final stage, depending on the presence and properties of repartition operators.
- private def finalStageOptimizerRules: Seq[Rule[SparkPlan]] = {
- val origins = inputPlan.collect {
- case s: ShuffleExchangeLike => s.shuffleOrigin
- }
- val allRules = queryStageOptimizerRules ++ postStageCreationRules
- allRules.filter {
- case c: AQEShuffleReadRule =>
- origins.forall(c.supportedShuffleOrigins.contains)
- case _ => true
- }
- }
-
- private def optimizeQueryStage(plan: SparkPlan, rules: Seq[Rule[SparkPlan]]): SparkPlan = {
- val optimized = rules.foldLeft(plan) { case (latestPlan, rule) =>
+ private def optimizeQueryStage(plan: SparkPlan, isFinalStage: Boolean): SparkPlan = {
+ val optimized = queryStageOptimizerRules.foldLeft(plan) { case (latestPlan, rule) =>
val applied = rule.apply(latestPlan)
val result = rule match {
- case c: AQEShuffleReadRule if c.mayAddExtraShuffles =>
- if (ValidateRequirements.validate(applied)) {
+ case _: AQEShuffleReadRule if !applied.fastEquals(latestPlan) =>
+ val distribution = if (isFinalStage) {
+ // If `requiredDistribution` is None, it means `EnsureRequirements` will not optimize
+ // out the user-specified repartition, thus we don't have a distribution requirement
+ // for the final plan.
+ requiredDistribution.getOrElse(UnspecifiedDistribution)
+ } else {
+ UnspecifiedDistribution
+ }
+ if (ValidateRequirements.validate(applied, distribution)) {
applied
} else {
- logDebug(s"Rule ${rule.ruleName} is not applied due to additional shuffles " +
- "will be introduced.")
+ logDebug(s"Rule ${rule.ruleName} is not applied as it breaks the " +
+ "distribution requirement of the query plan.")
latestPlan
}
case _ => applied
@@ -303,7 +311,10 @@ case class AdaptiveSparkPlanExec(
}
// Run the final plan when there's no more unfinished stages.
- currentPhysicalPlan = optimizeQueryStage(result.newPlan, finalStageOptimizerRules)
+ currentPhysicalPlan = applyPhysicalRules(
+ optimizeQueryStage(result.newPlan, isFinalStage = true),
+ postStageCreationRules,
+ Some((planChangeLogger, "AQE Post Stage Creation")))
isFinalPlan = true
executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
currentPhysicalPlan
@@ -520,7 +531,7 @@ case class AdaptiveSparkPlanExec(
}
private def newQueryStage(e: Exchange): QueryStageExec = {
- val optimizedPlan = optimizeQueryStage(e.child, queryStageOptimizerRules)
+ val optimizedPlan = optimizeQueryStage(e.child, isFinalStage = false)
val queryStage = e match {
case s: ShuffleExchangeLike =>
val newShuffle = applyPhysicalRules(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
index 7f3e453..75c53b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
@@ -33,6 +33,10 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe
Seq(ENSURE_REQUIREMENTS, REPARTITION_BY_COL, REBALANCE_PARTITIONS_BY_NONE,
REBALANCE_PARTITIONS_BY_COL)
+ override def isSupported(shuffle: ShuffleExchangeLike): Boolean = {
+ shuffle.outputPartitioning != SinglePartition && super.isSupported(shuffle)
+ }
+
override def apply(plan: SparkPlan): SparkPlan = {
if (!conf.coalesceShufflePartitionsEnabled) {
return plan
@@ -52,7 +56,7 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe
val shuffleStageInfos = collectShuffleStageInfos(plan)
// ShuffleExchanges introduced by repartition do not support changing the number of partitions.
// We change the number of partitions in the stage only if all the ShuffleExchanges support it.
- if (!shuffleStageInfos.forall(s => supportCoalesce(s.shuffleStage.shuffle))) {
+ if (!shuffleStageInfos.forall(s => isSupported(s.shuffleStage.shuffle))) {
plan
} else {
// Ideally, this rule should simply coalesce partition w.r.t. the target size specified by
@@ -106,10 +110,6 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe
}.getOrElse(plan)
case other => other.mapChildren(updateShuffleReads(_, specsMap))
}
-
- private def supportCoalesce(s: ShuffleExchangeLike): Boolean = {
- s.outputPartitioning != SinglePartition && supportedShuffleOrigins.contains(s.shuffleOrigin)
- }
}
private class ShuffleStageInfo(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeShuffleWithLocalRead.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeShuffleWithLocalRead.scala
index 844acbd..cf1c7ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeShuffleWithLocalRead.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeShuffleWithLocalRead.scala
@@ -38,7 +38,9 @@ object OptimizeShuffleWithLocalRead extends AQEShuffleReadRule {
override val supportedShuffleOrigins: Seq[ShuffleOrigin] =
Seq(ENSURE_REQUIREMENTS, REBALANCE_PARTITIONS_BY_NONE)
- override def mayAddExtraShuffles: Boolean = true
+ override protected def isSupported(shuffle: ShuffleExchangeLike): Boolean = {
+ shuffle.outputPartitioning != SinglePartition && super.isSupported(shuffle)
+ }
// The build side is a broadcast query stage which should have been optimized using local read
// already. So we only need to deal with probe side here.
@@ -136,14 +138,10 @@ object OptimizeShuffleWithLocalRead extends AQEShuffleReadRule {
def canUseLocalShuffleRead(plan: SparkPlan): Boolean = plan match {
case s: ShuffleQueryStageExec =>
- s.mapStats.isDefined && supportLocalRead(s.shuffle)
+ s.mapStats.isDefined && isSupported(s.shuffle)
case AQEShuffleReadExec(s: ShuffleQueryStageExec, _) =>
- s.mapStats.isDefined && supportLocalRead(s.shuffle) &&
+ s.mapStats.isDefined && isSupported(s.shuffle) &&
s.shuffle.shuffleOrigin == ENSURE_REQUIREMENTS
case _ => false
}
-
- private def supportLocalRead(s: ShuffleExchangeLike): Boolean = {
- s.outputPartitioning != SinglePartition && supportedShuffleOrigins.contains(s.shuffleOrigin)
- }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala
index dc437403..1752907 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala
@@ -38,7 +38,8 @@ import org.apache.spark.sql.internal.SQLConf
* ShuffleQueryStageExec.
*/
object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule {
- override def supportedShuffleOrigins: Seq[ShuffleOrigin] =
+
+ override val supportedShuffleOrigins: Seq[ShuffleOrigin] =
Seq(REBALANCE_PARTITIONS_BY_NONE, REBALANCE_PARTITIONS_BY_COL)
/**
@@ -92,9 +93,8 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule {
}
plan match {
- case shuffle: ShuffleQueryStageExec
- if supportedShuffleOrigins.contains(shuffle.shuffle.shuffleOrigin) =>
- tryOptimizeSkewedPartitions(shuffle)
+ case stage: ShuffleQueryStageExec if isSupported(stage.shuffle) =>
+ tryOptimizeSkewedPartitions(stage)
case _ => plan
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
index fbfbce6..88abe68 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
@@ -52,8 +52,6 @@ object OptimizeSkewedJoin extends AQEShuffleReadRule {
override val supportedShuffleOrigins: Seq[ShuffleOrigin] = Seq(ENSURE_REQUIREMENTS)
- override def mayAddExtraShuffles: Boolean = true
-
/**
* A partition is considered as a skewed partition if its size is larger than the median
* partition size * SKEW_JOIN_SKEWED_PARTITION_FACTOR and also larger than
@@ -257,13 +255,12 @@ object OptimizeSkewedJoin extends AQEShuffleReadRule {
plan
}
}
-}
-private object ShuffleStage {
- def unapply(plan: SparkPlan): Option[ShuffleQueryStageExec] = plan match {
- case s: ShuffleQueryStageExec if s.mapStats.isDefined &&
- OptimizeSkewedJoin.supportedShuffleOrigins.contains(s.shuffle.shuffleOrigin) =>
- Some(s)
- case _ => None
+ object ShuffleStage {
+ def unapply(plan: SparkPlan): Option[ShuffleQueryStageExec] = plan match {
+ case s: ShuffleQueryStageExec if s.mapStats.isDefined && isSupported(s.shuffle) =>
+ Some(s)
+ case _ => None
+ }
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index d71933a..23716f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -32,8 +32,14 @@ import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoin
* [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for
* each operator by inserting [[ShuffleExchangeExec]] Operators where required. Also ensure that
* the input partition ordering requirements are met.
+ *
+ * @param optimizeOutRepartition A flag to indicate whether this rule should optimize out
+ * user-specified repartition shuffles. This is mostly true, but can be false in AQE, where
+ * an AQE optimization may change the plan's output partitioning and the user-specified
+ * repartition shuffles need to be retained in the plan.
*/
-object EnsureRequirements extends Rule[SparkPlan] {
+case class EnsureRequirements(optimizeOutRepartition: Boolean = true) extends Rule[SparkPlan] {
private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
@@ -249,13 +255,9 @@ object EnsureRequirements extends Rule[SparkPlan] {
}
def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
- // TODO: remove this after we create a physical operator for `RepartitionByExpression`.
- // SPARK-35989: AQE will change the partition number so we should retain the REPARTITION_BY_NUM
- // shuffle which is specified by user. And also we can not remove REBALANCE_PARTITIONS_BY_COL,
- // it is a special shuffle used to rebalance partitions.
- // So, here we only remove REPARTITION_BY_COL in AQE.
case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, shuffleOrigin)
- if shuffleOrigin == REPARTITION_BY_COL || !conf.adaptiveExecutionEnabled =>
+ if optimizeOutRepartition &&
+ (shuffleOrigin == REPARTITION_BY_COL || shuffleOrigin == REPARTITION_BY_NUM) =>
def hasSemanticEqualPartitioning(partitioning: Partitioning): Boolean = {
partitioning match {
case lower: HashPartitioning if upper.semanticEquals(lower) => true
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ValidateRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ValidateRequirements.scala
index 6964d9c..5003db6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ValidateRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ValidateRequirements.scala
@@ -30,6 +30,10 @@ import org.apache.spark.sql.execution._
*/
object ValidateRequirements extends Logging {
+ def validate(plan: SparkPlan, requiredDistribution: Distribution): Boolean = {
+ validate(plan) && plan.outputPartitioning.satisfies(requiredDistribution)
+ }
+
def validate(plan: SparkPlan): Boolean = {
plan.children.forall(validate) && validateInternal(plan)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index fad6ed1..df310cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -40,6 +40,8 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
setupTestData()
+ private val EnsureRequirements = new EnsureRequirements()
+
private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
val planner = spark.sessionState.planner
import planner._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index dda94f1..ca8295e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -1496,6 +1496,60 @@ class AdaptiveQueryExecSuite
}.isDefined
}
+ def checkBHJ(
+ df: Dataset[Row],
+ optimizeOutRepartition: Boolean,
+ probeSideLocalRead: Boolean,
+ probeSideCoalescedRead: Boolean): Unit = {
+ df.collect()
+ val plan = df.queryExecution.executedPlan
+ // There should be only one shuffle that can't do local read, which is either the top shuffle
+ // from repartition, or BHJ probe side shuffle.
+ checkNumLocalShuffleReads(plan, 1)
+ assert(hasRepartitionShuffle(plan) == !optimizeOutRepartition)
+ val bhj = findTopLevelBroadcastHashJoin(plan)
+ assert(bhj.length == 1)
+
+ // Build side should do local read.
+ val buildSide = find(bhj.head.left)(_.isInstanceOf[AQEShuffleReadExec])
+ assert(buildSide.isDefined)
+ assert(buildSide.get.asInstanceOf[AQEShuffleReadExec].isLocalRead)
+
+ val probeSide = find(bhj.head.right)(_.isInstanceOf[AQEShuffleReadExec])
+ if (probeSideLocalRead || probeSideCoalescedRead) {
+ assert(probeSide.isDefined)
+ if (probeSideLocalRead) {
+ assert(probeSide.get.asInstanceOf[AQEShuffleReadExec].isLocalRead)
+ } else {
+ assert(probeSide.get.asInstanceOf[AQEShuffleReadExec].hasCoalescedPartition)
+ }
+ } else {
+ assert(probeSide.isEmpty)
+ }
+ }
+
+ def checkSMJ(
+ df: Dataset[Row],
+ optimizeOutRepartition: Boolean,
+ optimizeSkewJoin: Boolean,
+ coalescedRead: Boolean): Unit = {
+ df.collect()
+ val plan = df.queryExecution.executedPlan
+ assert(hasRepartitionShuffle(plan) == !optimizeOutRepartition)
+ val smj = findTopLevelSortMergeJoin(plan)
+ assert(smj.length == 1)
+ assert(smj.head.isSkewJoin == optimizeSkewJoin)
+ val aqeReads = collect(smj.head) {
+ case c: AQEShuffleReadExec => c
+ }
+ if (coalescedRead || optimizeSkewJoin) {
+ assert(aqeReads.length == 2)
+ if (coalescedRead) assert(aqeReads.forall(_.hasCoalescedPartition))
+ } else {
+ assert(aqeReads.isEmpty)
+ }
+ }
+
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
val df = sql(
@@ -1509,44 +1563,25 @@ class AdaptiveQueryExecSuite
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
// Repartition with no partition num specified.
- val dfRepartition = df.repartition('b)
- dfRepartition.collect()
- val plan = dfRepartition.queryExecution.executedPlan
- // The top shuffle from repartition is optimized out.
- assert(!hasRepartitionShuffle(plan))
- val bhj = findTopLevelBroadcastHashJoin(plan)
- assert(bhj.length == 1)
- checkNumLocalShuffleReads(plan, 1)
- // Probe side is coalesced.
- val aqeRead = bhj.head.right.find(_.isInstanceOf[AQEShuffleReadExec])
- assert(aqeRead.isDefined)
- assert(aqeRead.get.asInstanceOf[AQEShuffleReadExec].hasCoalescedPartition)
-
- // Repartition with partition default num specified.
- val dfRepartitionWithNum = df.repartition(5, 'b)
- dfRepartitionWithNum.collect()
- val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
- // The top shuffle from repartition is not optimized out.
- assert(hasRepartitionShuffle(planWithNum))
- val bhjWithNum = findTopLevelBroadcastHashJoin(planWithNum)
- assert(bhjWithNum.length == 1)
- checkNumLocalShuffleReads(planWithNum, 1)
- // Probe side is coalesced.
- assert(bhjWithNum.head.right.find(_.isInstanceOf[AQEShuffleReadExec]).nonEmpty)
-
- // Repartition with partition non-default num specified.
- val dfRepartitionWithNum2 = df.repartition(3, 'b)
- dfRepartitionWithNum2.collect()
- val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan
- // The top shuffle from repartition is not optimized out, and this is the only shuffle that
- // does not have local shuffle read.
- assert(hasRepartitionShuffle(planWithNum2))
- val bhjWithNum2 = findTopLevelBroadcastHashJoin(planWithNum2)
- assert(bhjWithNum2.length == 1)
- checkNumLocalShuffleReads(planWithNum2, 1)
- val aqeRead2 = bhjWithNum2.head.right.find(_.isInstanceOf[AQEShuffleReadExec])
- assert(aqeRead2.isDefined)
- assert(aqeRead2.get.asInstanceOf[AQEShuffleReadExec].isLocalRead)
+ checkBHJ(df.repartition('b),
+ // The top shuffle from repartition is optimized out.
+ optimizeOutRepartition = true, probeSideLocalRead = false, probeSideCoalescedRead = true)
+
+ // Repartition with default partition num (5 in test env) specified.
+ checkBHJ(df.repartition(5, 'b),
+ // The top shuffle from repartition is optimized out
+ // The final plan must have 5 partitions, no optimization can be made to the probe side.
+ optimizeOutRepartition = true, probeSideLocalRead = false, probeSideCoalescedRead = false)
+
+ // Repartition with non-default partition num specified.
+ checkBHJ(df.repartition(4, 'b),
+ // The top shuffle from repartition is not optimized out
+ optimizeOutRepartition = false, probeSideLocalRead = true, probeSideCoalescedRead = true)
+
+ // Repartition by col and project away the partition cols
+ checkBHJ(df.repartition('b).select('key),
+ // The top shuffle from repartition is not optimized out
+ optimizeOutRepartition = false, probeSideLocalRead = true, probeSideCoalescedRead = true)
}
// Force skew join
@@ -1556,46 +1591,25 @@ class AdaptiveQueryExecSuite
SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR.key -> "0",
SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10") {
// Repartition with no partition num specified.
- val dfRepartition = df.repartition('b)
- dfRepartition.collect()
- val plan = dfRepartition.queryExecution.executedPlan
- // The top shuffle from repartition is optimized out.
- assert(!hasRepartitionShuffle(plan))
- val smj = findTopLevelSortMergeJoin(plan)
- assert(smj.length == 1)
- // No skew join due to the repartition.
- assert(!smj.head.isSkewJoin)
- // Both sides are coalesced.
- val aqeReads = collect(smj.head) {
- case c: AQEShuffleReadExec if c.hasCoalescedPartition => c
- }
- assert(aqeReads.length == 2)
-
- // Repartition with default partition num specified.
- val dfRepartitionWithNum = df.repartition(5, 'b)
- dfRepartitionWithNum.collect()
- val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
- // The top shuffle from repartition is not optimized out.
- assert(hasRepartitionShuffle(planWithNum))
- val smjWithNum = findTopLevelSortMergeJoin(planWithNum)
- assert(smjWithNum.length == 1)
- // Skew join can apply as the repartition is not optimized out.
- assert(smjWithNum.head.isSkewJoin)
- val aqeReadsWithNum = collect(smjWithNum.head) {
- case c: AQEShuffleReadExec => c
- }
- assert(aqeReadsWithNum.nonEmpty)
-
- // Repartition with default non-partition num specified.
- val dfRepartitionWithNum2 = df.repartition(3, 'b)
- dfRepartitionWithNum2.collect()
- val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan
- // The top shuffle from repartition is not optimized out.
- assert(hasRepartitionShuffle(planWithNum2))
- val smjWithNum2 = findTopLevelSortMergeJoin(planWithNum2)
- assert(smjWithNum2.length == 1)
- // Skew join can apply as the repartition is not optimized out.
- assert(smjWithNum2.head.isSkewJoin)
+ checkSMJ(df.repartition('b),
+ // The top shuffle from repartition is optimized out.
+ optimizeOutRepartition = true, optimizeSkewJoin = false, coalescedRead = true)
+
+ // Repartition with default partition num (5 in test env) specified.
+ checkSMJ(df.repartition(5, 'b),
+ // The top shuffle from repartition is optimized out.
+ // The final plan must have 5 partitions, can't do coalesced read.
+ optimizeOutRepartition = true, optimizeSkewJoin = false, coalescedRead = false)
+
+ // Repartition with non-default partition num specified.
+ checkSMJ(df.repartition(4, 'b),
+ // The top shuffle from repartition is not optimized out.
+ optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false)
+
+ // Repartition by col and project away the partition cols
+ checkSMJ(df.repartition('b).select('key),
+ // The top shuffle from repartition is not optimized out.
+ optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false)
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
index 8f7616c..0425be6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -21,16 +21,17 @@ import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
import org.apache.spark.sql.execution.{DummySparkPlan, SortExec}
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
-class EnsureRequirementsSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
+class EnsureRequirementsSuite extends SharedSparkSession {
private val exprA = Literal(1)
private val exprB = Literal(2)
private val exprC = Literal(3)
+ private val EnsureRequirements = new EnsureRequirements()
+
test("reorder should handle PartitioningCollection") {
val plan1 = DummySparkPlan(
outputPartitioning = PartitioningCollection(Seq(
@@ -134,26 +135,4 @@ class EnsureRequirementsSuite extends SharedSparkSession with AdaptiveSparkPlanH
}.size == 2)
}
}
-
- test("SPARK-35989: Do not remove REPARTITION_BY_NUM shuffle if AQE is enabled") {
- import testImplicits._
- Seq(true, false).foreach { enableAqe =>
- withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAqe.toString,
- SQLConf.SHUFFLE_PARTITIONS.key -> "3",
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
- val df1 = Seq((1, 2)).toDF("c1", "c2")
- val df2 = Seq((1, 3)).toDF("c3", "c4")
- val res = df1.join(df2, $"c1" === $"c3").repartition(3, $"c1")
- val num = collect(res.queryExecution.executedPlan) {
- case shuffle: ShuffleExchangeExec if shuffle.shuffleOrigin == REPARTITION_BY_NUM =>
- shuffle
- }.size
- if (enableAqe) {
- assert(num == 1)
- } else {
- assert(num == 0)
- }
- }
- }
- }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 92c38ee..83163cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -48,6 +48,8 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
protected var spark: SparkSession = null
+ private val EnsureRequirements = new EnsureRequirements()
+
/**
* Create a new [[SparkSession]] running in local-cluster mode with unsafe and codegen enabled.
*/
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
index 3588b9d..71e59ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
@@ -31,6 +31,8 @@ import org.apache.spark.sql.types.{BooleanType, DoubleType, IntegerType, StructT
class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession {
+ private val EnsureRequirements = new EnsureRequirements()
+
private lazy val left = spark.createDataFrame(
sparkContext.parallelize(Seq(
Row(1, 2.0),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 5262320..653049b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -33,6 +33,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession {
import testImplicits.newProductEncoder
import testImplicits.localSeqToDatasetHolder
+ private val EnsureRequirements = new EnsureRequirements()
+
private lazy val myUpperCaseData = spark.createDataFrame(
sparkContext.parallelize(Seq(
Row(1, "A"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 744ee1c..f704fdb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -31,6 +31,8 @@ import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
class OuterJoinSuite extends SparkPlanTest with SharedSparkSession {
+ private val EnsureRequirements = new EnsureRequirements()
+
private lazy val left = spark.createDataFrame(
sparkContext.parallelize(Seq(
Row(1, 2.0),