You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2021/08/03 10:30:20 UTC

[spark] branch branch-3.2 updated: [SPARK-36315][SQL] Only skip AQEShuffleReadRule in the final stage if it breaks the distribution requirement

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 8d817dc  [SPARK-36315][SQL] Only skip AQEShuffleReadRule in the final stage if it breaks the distribution requirement
8d817dc is described below

commit 8d817dcf3084d56da22b909d578a644143f775d5
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Tue Aug 3 18:28:52 2021 +0800

    [SPARK-36315][SQL] Only skip AQEShuffleReadRule in the final stage if it breaks the distribution requirement
    
    ### What changes were proposed in this pull request?
    
    This is a followup of https://github.com/apache/spark/pull/30494
    
    This PR proposes a new way to optimize the final query stage in AQE. We first collect the effective user-specified repartition (semantic-wise, a user-specified repartition is only effective if it's the root node or under a few simple nodes), and get the required distribution for the final plan. When we optimize the final query stage, we skip an `AQEShuffleReadRule` if applying it breaks the required distribution.
    
    ### Why are the changes needed?
    
    The current solution for optimizing the final query stage is pretty hacky and overkill. As an example, the newly added rule `OptimizeSkewInRebalancePartitions` can hardly apply as it's very common that the query plan has shuffles with origin `ENSURE_REQUIREMENTS`, which is not supported by `OptimizeSkewInRebalancePartitions`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    updated tests
    
    Closes #33541 from cloud-fan/aqe.
    
    Authored-by: Wenchen Fan <we...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit dd80457ffb1c129a1ca3c53bcf3ea5feed7ebc57)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/sql/execution/QueryExecution.scala       |   2 +-
 .../execution/adaptive/AQEShuffleReadExec.scala    |  36 ++++-
 .../execution/adaptive/AQEShuffleReadRule.scala    |  12 +-
 .../spark/sql/execution/adaptive/AQEUtils.scala    |  60 ++++++++
 .../execution/adaptive/AdaptiveSparkPlanExec.scala |  63 ++++----
 .../adaptive/CoalesceShufflePartitions.scala       |  10 +-
 .../adaptive/OptimizeShuffleWithLocalRead.scala    |  12 +-
 .../OptimizeSkewInRebalancePartitions.scala        |   8 +-
 .../execution/adaptive/OptimizeSkewedJoin.scala    |  15 +-
 .../execution/exchange/EnsureRequirements.scala    |  16 +-
 .../execution/exchange/ValidateRequirements.scala  |   4 +
 .../apache/spark/sql/execution/PlannerSuite.scala  |   2 +
 .../adaptive/AdaptiveQueryExecSuite.scala          | 170 +++++++++++----------
 .../exchange/EnsureRequirementsSuite.scala         |  27 +---
 .../sql/execution/joins/BroadcastJoinSuite.scala   |   2 +
 .../sql/execution/joins/ExistenceJoinSuite.scala   |   2 +
 .../spark/sql/execution/joins/InnerJoinSuite.scala |   2 +
 .../spark/sql/execution/joins/OuterJoinSuite.scala |   2 +
 18 files changed, 274 insertions(+), 171 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 5a654ad..6c16dce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -409,7 +409,7 @@ object QueryExecution {
       PlanDynamicPruningFilters(sparkSession),
       PlanSubqueries(sparkSession),
       RemoveRedundantProjects,
-      EnsureRequirements,
+      EnsureRequirements(),
       // `RemoveRedundantSorts` needs to be added after `EnsureRequirements` to guarantee the same
       // number of partitions when instantiating PartitioningCollection.
       RemoveRedundantSorts,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
index 0768b9b..af62157 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
@@ -22,13 +22,13 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, UnknownPartitioning}
+import org.apache.spark.sql.catalyst.trees.CurrentOrigin
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeLike}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
-
 /**
  * A wrapper of shuffle query stage, which follows the given partition arrangement.
  *
@@ -51,6 +51,7 @@ case class AQEShuffleReadExec private(
   override def supportsColumnar: Boolean = child.supportsColumnar
 
   override def output: Seq[Attribute] = child.output
+
   override lazy val outputPartitioning: Partitioning = {
     // If it is a local shuffle read with one mapper per task, then the output partitioning is
     // the same as the plan before shuffle.
@@ -69,6 +70,21 @@ case class AQEShuffleReadExec private(
         case _ =>
           throw new IllegalStateException("operating on canonicalization plan")
       }
+    } else if (isCoalescedRead) {
+      // For coalesced shuffle read, the data distribution is not changed, only the number of
+      // partitions is changed.
+      child.outputPartitioning match {
+        case h: HashPartitioning =>
+          CurrentOrigin.withOrigin(h.origin)(h.copy(numPartitions = partitionSpecs.length))
+        case r: RangePartitioning =>
+          CurrentOrigin.withOrigin(r.origin)(r.copy(numPartitions = partitionSpecs.length))
+        // This can only happen for `REBALANCE_PARTITIONS_BY_NONE`, which uses
+        // `RoundRobinPartitioning` but we don't need to retain the number of partitions.
+        case r: RoundRobinPartitioning =>
+          r.copy(numPartitions = partitionSpecs.length)
+        case other => throw new IllegalStateException(
+          "Unexpected partitioning for coalesced shuffle read: " + other)
+      }
     } else {
       UnknownPartitioning(partitionSpecs.length)
     }
@@ -92,7 +108,7 @@ case class AQEShuffleReadExec private(
   /**
    * Returns true iff some partitions were actually combined
    */
-  private def isCoalesced(spec: ShufflePartitionSpec) = spec match {
+  private def isCoalescedSpec(spec: ShufflePartitionSpec) = spec match {
     case CoalescedPartitionSpec(0, 0, _) => true
     case s: CoalescedPartitionSpec => s.endReducerIndex - s.startReducerIndex > 1
     case _ => false
@@ -102,7 +118,7 @@ case class AQEShuffleReadExec private(
    * Returns true iff some non-empty partitions were combined
    */
   def hasCoalescedPartition: Boolean = {
-    partitionSpecs.exists(isCoalesced)
+    partitionSpecs.exists(isCoalescedSpec)
   }
 
   def hasSkewedPartition: Boolean =
@@ -112,6 +128,16 @@ case class AQEShuffleReadExec private(
     partitionSpecs.exists(_.isInstanceOf[PartialMapperPartitionSpec]) ||
       partitionSpecs.exists(_.isInstanceOf[CoalescedMapperPartitionSpec])
 
+  def isCoalescedRead: Boolean = {
+    partitionSpecs.sliding(2).forall {
+      // A single partition spec which is `CoalescedPartitionSpec` also means coalesced read.
+      case Seq(_: CoalescedPartitionSpec) => true
+      case Seq(l: CoalescedPartitionSpec, r: CoalescedPartitionSpec) =>
+        l.endReducerIndex <= r.startReducerIndex
+      case _ => false
+    }
+  }
+
   private def shuffleStage = child match {
     case stage: ShuffleQueryStageExec => Some(stage)
     case _ => None
@@ -159,7 +185,7 @@ case class AQEShuffleReadExec private(
 
     if (hasCoalescedPartition) {
       val numCoalescedPartitionsMetric = metrics("numCoalescedPartitions")
-      val x = partitionSpecs.count(isCoalesced)
+      val x = partitionSpecs.count(isCoalescedSpec)
       numCoalescedPartitionsMetric.set(x)
       driverAccumUpdates += numCoalescedPartitionsMetric.id -> x
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadRule.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadRule.scala
index 1c7f2ea..c303e85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadRule.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadRule.scala
@@ -19,17 +19,19 @@ package org.apache.spark.sql.execution.adaptive
 
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.exchange.ShuffleOrigin
+import org.apache.spark.sql.execution.exchange.{ShuffleExchangeLike, ShuffleOrigin}
 
 /**
- * Adaptive Query Execution rule that may create [[AQEShuffleReadExec]] on top of query stages.
+ * A rule that may create [[AQEShuffleReadExec]] on top of [[ShuffleQueryStageExec]] and change the
+ * plan output partitioning. The AQE framework will skip the rule if it leads to extra shuffles.
  */
 trait AQEShuffleReadRule extends Rule[SparkPlan] {
-
   /**
    * Returns the list of [[ShuffleOrigin]]s supported by this rule.
    */
-  def supportedShuffleOrigins: Seq[ShuffleOrigin]
+  protected def supportedShuffleOrigins: Seq[ShuffleOrigin]
 
-  def mayAddExtraShuffles: Boolean = false
+  protected def isSupported(shuffle: ShuffleExchangeLike): Boolean = {
+    supportedShuffleOrigins.contains(shuffle.shuffleOrigin)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEUtils.scala
new file mode 100644
index 0000000..277af21
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEUtils.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution, HashPartitioning, UnspecifiedDistribution}
+import org.apache.spark.sql.execution.{CollectMetricsExec, FilterExec, ProjectExec, SortExec, SparkPlan}
+import org.apache.spark.sql.execution.exchange.{REPARTITION_BY_COL, REPARTITION_BY_NUM, ShuffleExchangeExec}
+
+object AQEUtils {
+
+  // Analyze the given plan and calculate the required distribution of this plan w.r.t. the
+  // user-specified repartition.
+  def getRequiredDistribution(p: SparkPlan): Option[Distribution] = p match {
+    // User-specified repartition is only effective when it's the root node, or under
+    // Project/Filter/LocalSort/CollectMetrics.
+    // Note: we only care about `HashPartitioning` as `EnsureRequirements` can only optimize out
+    // user-specified repartition with `HashPartitioning`.
+    case ShuffleExchangeExec(h: HashPartitioning, _, shuffleOrigin)
+        if shuffleOrigin == REPARTITION_BY_COL || shuffleOrigin == REPARTITION_BY_NUM =>
+      val numPartitions = if (shuffleOrigin == REPARTITION_BY_NUM) {
+        Some(h.numPartitions)
+      } else {
+        None
+      }
+      Some(HashClusteredDistribution(h.expressions, numPartitions))
+    case f: FilterExec => getRequiredDistribution(f.child)
+    case s: SortExec if !s.global => getRequiredDistribution(s.child)
+    case c: CollectMetricsExec => getRequiredDistribution(c.child)
+    case p: ProjectExec =>
+      getRequiredDistribution(p.child).flatMap {
+        case h: HashClusteredDistribution =>
+          if (h.expressions.forall(e => p.projectList.exists(_.semanticEquals(e)))) {
+            Some(h)
+          } else {
+            // It's possible that the user-specified repartition is effective but the output
+            // partitioning is not retained, e.g. `df.repartition(a, b).select(c)`. We can't
+            // handle this case with required distribution. Here we return None and later on
+            // `EnsureRequirements` will skip optimizing out the user-specified repartition.
+            None
+          }
+        case other => Some(other)
+      }
+    case _ => Some(UnspecifiedDistribution)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index c03bb4b..9db4574 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
+import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution}
 import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.errors.QueryExecutionErrors
@@ -83,12 +84,28 @@ case class AdaptiveSparkPlanExec(
   // The logical plan optimizer for re-optimizing the current logical plan.
   @transient private val optimizer = new AQEOptimizer(conf)
 
+  // `EnsureRequirements` may remove user-specified repartition and assume the query plan won't
+  // change its output partitioning. This assumption is not true in AQE. Here we check the
+  // `inputPlan` which has not been processed by `EnsureRequirements` yet, to find out the
+  // effective user-specified repartition. Later on, the AQE framework will make sure the final
+  // output partitioning is not changed w.r.t. the effective user-specified repartition.
+  @transient private val requiredDistribution: Option[Distribution] = if (isSubquery) {
+    // Subquery output does not need a specific output partitioning.
+    Some(UnspecifiedDistribution)
+  } else {
+    AQEUtils.getRequiredDistribution(inputPlan)
+  }
+
   // A list of physical plan rules to be applied before creation of query stages. The physical
   // plan should reach a final status of query stages (i.e., no more addition or removal of
   // Exchange nodes) after running these rules.
   private def queryStagePreparationRules: Seq[Rule[SparkPlan]] = Seq(
     RemoveRedundantProjects,
-    EnsureRequirements,
+    // For cases like `df.repartition(a, b).select(c)`, there is no distribution requirement for
+    // the final plan, but we do need to respect the user-specified repartition. Here we ask
+    // `EnsureRequirements` to not optimize out the user-specified repartition-by-col to work
+    // around this case.
+    EnsureRequirements(optimizeOutRepartition = requiredDistribution.isDefined),
     RemoveRedundantSorts,
     DisableUnnecessaryBucketedScan
   ) ++ context.session.sessionState.queryStagePrepRules
@@ -114,33 +131,24 @@ case class AdaptiveSparkPlanExec(
     CollapseCodegenStages()
   ) ++ context.session.sessionState.postStageCreationRules
 
-  // The partitioning of the query output depends on the shuffle(s) in the final stage. If the
-  // original plan contains a repartition operator, we need to preserve the specified partitioning,
-  // whether or not the repartition-introduced shuffle is optimized out because of an underlying
-  // shuffle of the same partitioning. Thus, we need to exclude some `AQEShuffleReadRule`s
-  // from the final stage, depending on the presence and properties of repartition operators.
-  private def finalStageOptimizerRules: Seq[Rule[SparkPlan]] = {
-    val origins = inputPlan.collect {
-      case s: ShuffleExchangeLike => s.shuffleOrigin
-    }
-    val allRules = queryStageOptimizerRules ++ postStageCreationRules
-    allRules.filter {
-      case c: AQEShuffleReadRule =>
-        origins.forall(c.supportedShuffleOrigins.contains)
-      case _ => true
-    }
-  }
-
-  private def optimizeQueryStage(plan: SparkPlan, rules: Seq[Rule[SparkPlan]]): SparkPlan = {
-    val optimized = rules.foldLeft(plan) { case (latestPlan, rule) =>
+  private def optimizeQueryStage(plan: SparkPlan, isFinalStage: Boolean): SparkPlan = {
+    val optimized = queryStageOptimizerRules.foldLeft(plan) { case (latestPlan, rule) =>
       val applied = rule.apply(latestPlan)
       val result = rule match {
-        case c: AQEShuffleReadRule if c.mayAddExtraShuffles =>
-          if (ValidateRequirements.validate(applied)) {
+        case _: AQEShuffleReadRule if !applied.fastEquals(latestPlan) =>
+          val distribution = if (isFinalStage) {
+            // If `requiredDistribution` is None, it means `EnsureRequirements` will not optimize
+            // out the user-specified repartition, thus we don't have a distribution requirement
+            // for the final plan.
+            requiredDistribution.getOrElse(UnspecifiedDistribution)
+          } else {
+            UnspecifiedDistribution
+          }
+          if (ValidateRequirements.validate(applied, distribution)) {
             applied
           } else {
-            logDebug(s"Rule ${rule.ruleName} is not applied due to additional shuffles " +
-              "will be introduced.")
+            logDebug(s"Rule ${rule.ruleName} is not applied as it breaks the " +
+              "distribution requirement of the query plan.")
             latestPlan
           }
         case _ => applied
@@ -303,7 +311,10 @@ case class AdaptiveSparkPlanExec(
       }
 
       // Run the final plan when there's no more unfinished stages.
-      currentPhysicalPlan = optimizeQueryStage(result.newPlan, finalStageOptimizerRules)
+      currentPhysicalPlan = applyPhysicalRules(
+        optimizeQueryStage(result.newPlan, isFinalStage = true),
+        postStageCreationRules,
+        Some((planChangeLogger, "AQE Post Stage Creation")))
       isFinalPlan = true
       executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
       currentPhysicalPlan
@@ -520,7 +531,7 @@ case class AdaptiveSparkPlanExec(
   }
 
   private def newQueryStage(e: Exchange): QueryStageExec = {
-    val optimizedPlan = optimizeQueryStage(e.child, queryStageOptimizerRules)
+    val optimizedPlan = optimizeQueryStage(e.child, isFinalStage = false)
     val queryStage = e match {
       case s: ShuffleExchangeLike =>
         val newShuffle = applyPhysicalRules(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
index 7f3e453..75c53b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
@@ -33,6 +33,10 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe
     Seq(ENSURE_REQUIREMENTS, REPARTITION_BY_COL, REBALANCE_PARTITIONS_BY_NONE,
       REBALANCE_PARTITIONS_BY_COL)
 
+  override def isSupported(shuffle: ShuffleExchangeLike): Boolean = {
+    shuffle.outputPartitioning != SinglePartition && super.isSupported(shuffle)
+  }
+
   override def apply(plan: SparkPlan): SparkPlan = {
     if (!conf.coalesceShufflePartitionsEnabled) {
       return plan
@@ -52,7 +56,7 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe
     val shuffleStageInfos = collectShuffleStageInfos(plan)
     // ShuffleExchanges introduced by repartition do not support changing the number of partitions.
     // We change the number of partitions in the stage only if all the ShuffleExchanges support it.
-    if (!shuffleStageInfos.forall(s => supportCoalesce(s.shuffleStage.shuffle))) {
+    if (!shuffleStageInfos.forall(s => isSupported(s.shuffleStage.shuffle))) {
       plan
     } else {
       // Ideally, this rule should simply coalesce partition w.r.t. the target size specified by
@@ -106,10 +110,6 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe
       }.getOrElse(plan)
     case other => other.mapChildren(updateShuffleReads(_, specsMap))
   }
-
-  private def supportCoalesce(s: ShuffleExchangeLike): Boolean = {
-    s.outputPartitioning != SinglePartition && supportedShuffleOrigins.contains(s.shuffleOrigin)
-  }
 }
 
 private class ShuffleStageInfo(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeShuffleWithLocalRead.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeShuffleWithLocalRead.scala
index 844acbd..cf1c7ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeShuffleWithLocalRead.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeShuffleWithLocalRead.scala
@@ -38,7 +38,9 @@ object OptimizeShuffleWithLocalRead extends AQEShuffleReadRule {
   override val supportedShuffleOrigins: Seq[ShuffleOrigin] =
     Seq(ENSURE_REQUIREMENTS, REBALANCE_PARTITIONS_BY_NONE)
 
-  override def mayAddExtraShuffles: Boolean = true
+  override protected def isSupported(shuffle: ShuffleExchangeLike): Boolean = {
+    shuffle.outputPartitioning != SinglePartition && super.isSupported(shuffle)
+  }
 
   // The build side is a broadcast query stage which should have been optimized using local read
   // already. So we only need to deal with probe side here.
@@ -136,14 +138,10 @@ object OptimizeShuffleWithLocalRead extends AQEShuffleReadRule {
 
   def canUseLocalShuffleRead(plan: SparkPlan): Boolean = plan match {
     case s: ShuffleQueryStageExec =>
-      s.mapStats.isDefined && supportLocalRead(s.shuffle)
+      s.mapStats.isDefined && isSupported(s.shuffle)
     case AQEShuffleReadExec(s: ShuffleQueryStageExec, _) =>
-      s.mapStats.isDefined && supportLocalRead(s.shuffle) &&
+      s.mapStats.isDefined && isSupported(s.shuffle) &&
         s.shuffle.shuffleOrigin == ENSURE_REQUIREMENTS
     case _ => false
   }
-
-  private def supportLocalRead(s: ShuffleExchangeLike): Boolean = {
-    s.outputPartitioning != SinglePartition && supportedShuffleOrigins.contains(s.shuffleOrigin)
-  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala
index dc437403..1752907 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala
@@ -38,7 +38,8 @@ import org.apache.spark.sql.internal.SQLConf
  * ShuffleQueryStageExec.
  */
 object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule {
-  override def supportedShuffleOrigins: Seq[ShuffleOrigin] =
+
+  override val supportedShuffleOrigins: Seq[ShuffleOrigin] =
     Seq(REBALANCE_PARTITIONS_BY_NONE, REBALANCE_PARTITIONS_BY_COL)
 
   /**
@@ -92,9 +93,8 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule {
     }
 
     plan match {
-      case shuffle: ShuffleQueryStageExec
-          if supportedShuffleOrigins.contains(shuffle.shuffle.shuffleOrigin) =>
-        tryOptimizeSkewedPartitions(shuffle)
+      case stage: ShuffleQueryStageExec if isSupported(stage.shuffle) =>
+        tryOptimizeSkewedPartitions(stage)
       case _ => plan
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
index fbfbce6..88abe68 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
@@ -52,8 +52,6 @@ object OptimizeSkewedJoin extends AQEShuffleReadRule {
 
   override val supportedShuffleOrigins: Seq[ShuffleOrigin] = Seq(ENSURE_REQUIREMENTS)
 
-  override def mayAddExtraShuffles: Boolean = true
-
   /**
    * A partition is considered as a skewed partition if its size is larger than the median
    * partition size * SKEW_JOIN_SKEWED_PARTITION_FACTOR and also larger than
@@ -257,13 +255,12 @@ object OptimizeSkewedJoin extends AQEShuffleReadRule {
       plan
     }
   }
-}
 
-private object ShuffleStage {
-  def unapply(plan: SparkPlan): Option[ShuffleQueryStageExec] = plan match {
-    case s: ShuffleQueryStageExec if s.mapStats.isDefined &&
-        OptimizeSkewedJoin.supportedShuffleOrigins.contains(s.shuffle.shuffleOrigin) =>
-      Some(s)
-    case _ => None
+  object ShuffleStage {
+    def unapply(plan: SparkPlan): Option[ShuffleQueryStageExec] = plan match {
+      case s: ShuffleQueryStageExec if s.mapStats.isDefined && isSupported(s.shuffle) =>
+        Some(s)
+      case _ => None
+    }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index d71933a..23716f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -32,8 +32,14 @@ import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoin
  * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for
  * each operator by inserting [[ShuffleExchangeExec]] Operators where required.  Also ensure that
  * the input partition ordering requirements are met.
+ *
+ * @param optimizeOutRepartition A flag to indicate whether this rule should optimize out
+ *                               user-specified repartition shuffles. This is mostly true, but
+ *                               can be false in AQE, where AQE optimizations may change the plan
+ *                               output partitioning and we need to retain the user-specified
+ *                               repartition shuffles in the plan.
  */
-object EnsureRequirements extends Rule[SparkPlan] {
+case class EnsureRequirements(optimizeOutRepartition: Boolean = true) extends Rule[SparkPlan] {
 
   private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
     val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
@@ -249,13 +255,9 @@ object EnsureRequirements extends Rule[SparkPlan] {
   }
 
   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
-    // TODO: remove this after we create a physical operator for `RepartitionByExpression`.
-    // SPARK-35989: AQE will change the partition number so we should retain the REPARTITION_BY_NUM
-    // shuffle which is specified by user. And also we can not remove REBALANCE_PARTITIONS_BY_COL,
-    // it is a special shuffle used to rebalance partitions.
-    // So, here we only remove REPARTITION_BY_COL in AQE.
     case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, shuffleOrigin)
-        if shuffleOrigin == REPARTITION_BY_COL || !conf.adaptiveExecutionEnabled =>
+        if optimizeOutRepartition &&
+          (shuffleOrigin == REPARTITION_BY_COL || shuffleOrigin == REPARTITION_BY_NUM) =>
       def hasSemanticEqualPartitioning(partitioning: Partitioning): Boolean = {
         partitioning match {
           case lower: HashPartitioning if upper.semanticEquals(lower) => true
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ValidateRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ValidateRequirements.scala
index 6964d9c..5003db6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ValidateRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ValidateRequirements.scala
@@ -30,6 +30,10 @@ import org.apache.spark.sql.execution._
  */
 object ValidateRequirements extends Logging {
 
+  def validate(plan: SparkPlan, requiredDistribution: Distribution): Boolean = {
+    validate(plan) && plan.outputPartitioning.satisfies(requiredDistribution)
+  }
+
   def validate(plan: SparkPlan): Boolean = {
     plan.children.forall(validate) && validateInternal(plan)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index fad6ed1..df310cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -40,6 +40,8 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
 
   setupTestData()
 
+  private val EnsureRequirements = new EnsureRequirements()
+
   private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
     val planner = spark.sessionState.planner
     import planner._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index dda94f1..ca8295e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -1496,6 +1496,60 @@ class AdaptiveQueryExecSuite
       }.isDefined
     }
 
+    def checkBHJ(
+        df: Dataset[Row],
+        optimizeOutRepartition: Boolean,
+        probeSideLocalRead: Boolean,
+        probeSideCoalescedRead: Boolean): Unit = {
+      df.collect()
+      val plan = df.queryExecution.executedPlan
+      // There should be only one shuffle that can't do local read, which is either the top shuffle
+      // from repartition, or BHJ probe side shuffle.
+      checkNumLocalShuffleReads(plan, 1)
+      assert(hasRepartitionShuffle(plan) == !optimizeOutRepartition)
+      val bhj = findTopLevelBroadcastHashJoin(plan)
+      assert(bhj.length == 1)
+
+      // Build side should do local read.
+      val buildSide = find(bhj.head.left)(_.isInstanceOf[AQEShuffleReadExec])
+      assert(buildSide.isDefined)
+      assert(buildSide.get.asInstanceOf[AQEShuffleReadExec].isLocalRead)
+
+      val probeSide = find(bhj.head.right)(_.isInstanceOf[AQEShuffleReadExec])
+      if (probeSideLocalRead || probeSideCoalescedRead) {
+        assert(probeSide.isDefined)
+        if (probeSideLocalRead) {
+          assert(probeSide.get.asInstanceOf[AQEShuffleReadExec].isLocalRead)
+        } else {
+          assert(probeSide.get.asInstanceOf[AQEShuffleReadExec].hasCoalescedPartition)
+        }
+      } else {
+        assert(probeSide.isEmpty)
+      }
+    }
+
+    def checkSMJ(
+        df: Dataset[Row],
+        optimizeOutRepartition: Boolean,
+        optimizeSkewJoin: Boolean,
+        coalescedRead: Boolean): Unit = {
+      df.collect()
+      val plan = df.queryExecution.executedPlan
+      assert(hasRepartitionShuffle(plan) == !optimizeOutRepartition)
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.length == 1)
+      assert(smj.head.isSkewJoin == optimizeSkewJoin)
+      val aqeReads = collect(smj.head) {
+        case c: AQEShuffleReadExec => c
+      }
+      if (coalescedRead || optimizeSkewJoin) {
+        assert(aqeReads.length == 2)
+        if (coalescedRead) assert(aqeReads.forall(_.hasCoalescedPartition))
+      } else {
+        assert(aqeReads.isEmpty)
+      }
+    }
+
     withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
       SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
       val df = sql(
@@ -1509,44 +1563,25 @@ class AdaptiveQueryExecSuite
 
       withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
         // Repartition with no partition num specified.
-        val dfRepartition = df.repartition('b)
-        dfRepartition.collect()
-        val plan = dfRepartition.queryExecution.executedPlan
-        // The top shuffle from repartition is optimized out.
-        assert(!hasRepartitionShuffle(plan))
-        val bhj = findTopLevelBroadcastHashJoin(plan)
-        assert(bhj.length == 1)
-        checkNumLocalShuffleReads(plan, 1)
-        // Probe side is coalesced.
-        val aqeRead = bhj.head.right.find(_.isInstanceOf[AQEShuffleReadExec])
-        assert(aqeRead.isDefined)
-        assert(aqeRead.get.asInstanceOf[AQEShuffleReadExec].hasCoalescedPartition)
-
-        // Repartition with partition default num specified.
-        val dfRepartitionWithNum = df.repartition(5, 'b)
-        dfRepartitionWithNum.collect()
-        val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
-        // The top shuffle from repartition is not optimized out.
-        assert(hasRepartitionShuffle(planWithNum))
-        val bhjWithNum = findTopLevelBroadcastHashJoin(planWithNum)
-        assert(bhjWithNum.length == 1)
-        checkNumLocalShuffleReads(planWithNum, 1)
-        // Probe side is coalesced.
-        assert(bhjWithNum.head.right.find(_.isInstanceOf[AQEShuffleReadExec]).nonEmpty)
-
-        // Repartition with partition non-default num specified.
-        val dfRepartitionWithNum2 = df.repartition(3, 'b)
-        dfRepartitionWithNum2.collect()
-        val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan
-        // The top shuffle from repartition is not optimized out, and this is the only shuffle that
-        // does not have local shuffle read.
-        assert(hasRepartitionShuffle(planWithNum2))
-        val bhjWithNum2 = findTopLevelBroadcastHashJoin(planWithNum2)
-        assert(bhjWithNum2.length == 1)
-        checkNumLocalShuffleReads(planWithNum2, 1)
-        val aqeRead2 = bhjWithNum2.head.right.find(_.isInstanceOf[AQEShuffleReadExec])
-        assert(aqeRead2.isDefined)
-        assert(aqeRead2.get.asInstanceOf[AQEShuffleReadExec].isLocalRead)
+        checkBHJ(df.repartition('b),
+          // The top shuffle from repartition is optimized out.
+          optimizeOutRepartition = true, probeSideLocalRead = false, probeSideCoalescedRead = true)
+
+        // Repartition with default partition num (5 in test env) specified.
+        checkBHJ(df.repartition(5, 'b),
+          // The top shuffle from repartition is optimized out.
+          // The final plan must have 5 partitions; no optimization can be made to the probe side.
+          optimizeOutRepartition = true, probeSideLocalRead = false, probeSideCoalescedRead = false)
+
+        // Repartition with non-default partition num specified.
+        checkBHJ(df.repartition(4, 'b),
+          // The top shuffle from repartition is not optimized out.
+          optimizeOutRepartition = false, probeSideLocalRead = true, probeSideCoalescedRead = true)
+
+        // Repartition by col and project away the partition cols
+        checkBHJ(df.repartition('b).select('key),
+          // The top shuffle from repartition is not optimized out.
+          optimizeOutRepartition = false, probeSideLocalRead = true, probeSideCoalescedRead = true)
       }
 
       // Force skew join
@@ -1556,46 +1591,25 @@ class AdaptiveQueryExecSuite
         SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR.key -> "0",
         SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10") {
         // Repartition with no partition num specified.
-        val dfRepartition = df.repartition('b)
-        dfRepartition.collect()
-        val plan = dfRepartition.queryExecution.executedPlan
-        // The top shuffle from repartition is optimized out.
-        assert(!hasRepartitionShuffle(plan))
-        val smj = findTopLevelSortMergeJoin(plan)
-        assert(smj.length == 1)
-        // No skew join due to the repartition.
-        assert(!smj.head.isSkewJoin)
-        // Both sides are coalesced.
-        val aqeReads = collect(smj.head) {
-          case c: AQEShuffleReadExec if c.hasCoalescedPartition => c
-        }
-        assert(aqeReads.length == 2)
-
-        // Repartition with default partition num specified.
-        val dfRepartitionWithNum = df.repartition(5, 'b)
-        dfRepartitionWithNum.collect()
-        val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
-        // The top shuffle from repartition is not optimized out.
-        assert(hasRepartitionShuffle(planWithNum))
-        val smjWithNum = findTopLevelSortMergeJoin(planWithNum)
-        assert(smjWithNum.length == 1)
-        // Skew join can apply as the repartition is not optimized out.
-        assert(smjWithNum.head.isSkewJoin)
-        val aqeReadsWithNum = collect(smjWithNum.head) {
-          case c: AQEShuffleReadExec => c
-        }
-        assert(aqeReadsWithNum.nonEmpty)
-
-        // Repartition with default non-partition num specified.
-        val dfRepartitionWithNum2 = df.repartition(3, 'b)
-        dfRepartitionWithNum2.collect()
-        val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan
-        // The top shuffle from repartition is not optimized out.
-        assert(hasRepartitionShuffle(planWithNum2))
-        val smjWithNum2 = findTopLevelSortMergeJoin(planWithNum2)
-        assert(smjWithNum2.length == 1)
-        // Skew join can apply as the repartition is not optimized out.
-        assert(smjWithNum2.head.isSkewJoin)
+        checkSMJ(df.repartition('b),
+          // The top shuffle from repartition is optimized out.
+          optimizeOutRepartition = true, optimizeSkewJoin = false, coalescedRead = true)
+
+        // Repartition with default partition num (5 in test env) specified.
+        checkSMJ(df.repartition(5, 'b),
+          // The top shuffle from repartition is optimized out.
+          // The final plan must have 5 partitions, can't do coalesced read.
+          optimizeOutRepartition = true, optimizeSkewJoin = false, coalescedRead = false)
+
+        // Repartition with non-default partition num specified.
+        checkSMJ(df.repartition(4, 'b),
+          // The top shuffle from repartition is not optimized out.
+          optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false)
+
+        // Repartition by col and project away the partition cols
+        checkSMJ(df.repartition('b).select('key),
+          // The top shuffle from repartition is not optimized out.
+          optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false)
       }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
index 8f7616c..0425be6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -21,16 +21,17 @@ import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
 import org.apache.spark.sql.execution.{DummySparkPlan, SortExec}
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 
-class EnsureRequirementsSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
+class EnsureRequirementsSuite extends SharedSparkSession {
   private val exprA = Literal(1)
   private val exprB = Literal(2)
   private val exprC = Literal(3)
 
+  private val EnsureRequirements = new EnsureRequirements()
+
   test("reorder should handle PartitioningCollection") {
     val plan1 = DummySparkPlan(
       outputPartitioning = PartitioningCollection(Seq(
@@ -134,26 +135,4 @@ class EnsureRequirementsSuite extends SharedSparkSession with AdaptiveSparkPlanH
       }.size == 2)
     }
   }
-
-  test("SPARK-35989: Do not remove REPARTITION_BY_NUM shuffle if AQE is enabled") {
-    import testImplicits._
-    Seq(true, false).foreach { enableAqe =>
-      withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAqe.toString,
-          SQLConf.SHUFFLE_PARTITIONS.key -> "3",
-          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
-        val df1 = Seq((1, 2)).toDF("c1", "c2")
-        val df2 = Seq((1, 3)).toDF("c3", "c4")
-        val res = df1.join(df2, $"c1" === $"c3").repartition(3, $"c1")
-        val num = collect(res.queryExecution.executedPlan) {
-          case shuffle: ShuffleExchangeExec if shuffle.shuffleOrigin == REPARTITION_BY_NUM =>
-            shuffle
-        }.size
-        if (enableAqe) {
-          assert(num == 1)
-        } else {
-          assert(num == 0)
-        }
-      }
-    }
-  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 92c38ee..83163cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -48,6 +48,8 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
 
   protected var spark: SparkSession = null
 
+  private val EnsureRequirements = new EnsureRequirements()
+
   /**
    * Create a new [[SparkSession]] running in local-cluster mode with unsafe and codegen enabled.
    */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
index 3588b9d..71e59ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
@@ -31,6 +31,8 @@ import org.apache.spark.sql.types.{BooleanType, DoubleType, IntegerType, StructT
 
 class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession {
 
+  private val EnsureRequirements = new EnsureRequirements()
+
   private lazy val left = spark.createDataFrame(
     sparkContext.parallelize(Seq(
       Row(1, 2.0),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 5262320..653049b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -33,6 +33,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession {
   import testImplicits.newProductEncoder
   import testImplicits.localSeqToDatasetHolder
 
+  private val EnsureRequirements = new EnsureRequirements()
+
   private lazy val myUpperCaseData = spark.createDataFrame(
     sparkContext.parallelize(Seq(
       Row(1, "A"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 744ee1c..f704fdb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -31,6 +31,8 @@ import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
 
 class OuterJoinSuite extends SparkPlanTest with SharedSparkSession {
 
+  private val EnsureRequirements = new EnsureRequirements()
+
   private lazy val left = spark.createDataFrame(
     sparkContext.parallelize(Seq(
       Row(1, 2.0),

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org