You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ge...@apache.org on 2021/04/23 04:43:41 UTC

[spark] branch master updated: [SPARK-35075][SQL] Add traversal pruning for subquery related rules

This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 47f8687  [SPARK-35075][SQL] Add traversal pruning for subquery related rules
47f8687 is described below

commit 47f86875f73edc3bec56d3610ab46a16ec37091c
Author: Yingyi Bu <yi...@databricks.com>
AuthorDate: Fri Apr 23 12:42:55 2021 +0800

    [SPARK-35075][SQL] Add traversal pruning for subquery related rules
    
    ### What changes were proposed in this pull request?
    
    Added the following TreePattern enums:
    - DYNAMIC_PRUNING_SUBQUERY
    - EXISTS_SUBQUERY
    - IN_SUBQUERY
    - LIST_SUBQUERY
    - PLAN_EXPRESSION
    - SCALAR_SUBQUERY
    - FILTER
    
    Used them in the following rules:
    - ResolveSubquery
    - UpdateOuterReferences
    - OptimizeSubqueries
    - RewritePredicateSubquery
    - PullupCorrelatedPredicates
    - RewriteCorrelatedScalarSubquery (not the rule itself, only an internal transform call; full support is in SPARK-35148)
    - InsertAdaptiveSparkPlan
    - PlanAdaptiveSubqueries
    
    ### Why are the changes needed?
    
    Reduce the number of tree traversals and hence improve the query compilation latency.
    
    ### How was this patch tested?
    
    Existing tests.
    
    Closes #32247 from sigmod/subquery.
    
    Authored-by: Yingyi Bu <yi...@databricks.com>
    Signed-off-by: Gengliang Wang <lt...@gmail.com>
---
 .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 14 +++++++-------
 .../spark/sql/catalyst/expressions/DynamicPruning.scala   |  3 +++
 .../spark/sql/catalyst/expressions/predicates.scala       |  3 ++-
 .../apache/spark/sql/catalyst/expressions/subquery.scala  | 13 +++++++++++++
 .../apache/spark/sql/catalyst/optimizer/Optimizer.scala   |  4 +++-
 .../apache/spark/sql/catalyst/optimizer/subquery.scala    | 15 ++++++++++-----
 .../catalyst/plans/logical/basicLogicalOperators.scala    |  4 +++-
 .../spark/sql/catalyst/rules/RuleIdCollection.scala       |  3 +++
 .../apache/spark/sql/catalyst/trees/TreePatterns.scala    |  7 +++++++
 .../sql/execution/adaptive/InsertAdaptiveSparkPlan.scala  |  4 ++++
 .../sql/execution/adaptive/PlanAdaptiveSubqueries.scala   |  5 ++++-
 .../dynamicpruning/PlanDynamicPruningFilters.scala        |  3 ++-
 .../scala/org/apache/spark/sql/execution/subquery.scala   |  3 ++-
 13 files changed, 63 insertions(+), 18 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index c2c146c..87b8d52 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -39,9 +39,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
 import org.apache.spark.sql.catalyst.trees.TreeNodeRef
-import org.apache.spark.sql.catalyst.trees.TreePattern.{
-  EXPRESSION_WITH_RANDOM_SEED, NATURAL_LIKE_JOIN, WINDOW_EXPRESSION
-}
+import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils}
 import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
@@ -2179,7 +2177,8 @@ class Analyzer(override val catalogManager: CatalogManager)
      *     outer plan to get evaluated.
      */
     private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = {
-      plan transformExpressions {
+      plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY,
+        EXISTS_SUBQUERY, IN_SUBQUERY), ruleId) {
         case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved =>
           resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
         case e @ Exists(sub, _, exprId) if !sub.resolved =>
@@ -2196,7 +2195,8 @@ class Analyzer(override val catalogManager: CatalogManager)
     /**
      * Resolve and rewrite all subqueries in an operator tree..
      */
-    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
+      _.containsAnyPattern(SCALAR_SUBQUERY, EXISTS_SUBQUERY, IN_SUBQUERY), ruleId) {
       // In case of HAVING (a filter after an aggregate) we use both the aggregate and
       // its child for resolution.
       case f @ Filter(_, a: Aggregate) if f.childrenResolved =>
@@ -3790,9 +3790,9 @@ object UpdateOuterReferences extends Rule[LogicalPlan] {
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = {
-    plan resolveOperators {
+    plan.resolveOperatorsWithPruning(_.containsAllPatterns(PLAN_EXPRESSION, FILTER), ruleId) {
       case f @ Filter(_, a: Aggregate) if f.resolved =>
-        f transformExpressions {
+        f.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) {
           case s: SubqueryExpression if s.children.nonEmpty =>
             // Collect the aliases from output of aggregate.
             val outerAliases = a.aggregateExpressions collect { case a: Alias => a }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
index de4b874..1c185dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.TreePattern.{DYNAMIC_PRUNING_SUBQUERY, TreePattern}
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 
 trait DynamicPruning extends Predicate
@@ -69,6 +70,8 @@ case class DynamicPruningSubquery(
       pruningKey.dataType == buildKeys(broadcastKeyIndex).dataType
   }
 
+  final override def nodePatternsInternal: Seq[TreePattern] = Seq(DYNAMIC_PRUNING_SUBQUERY)
+
   override def toString: String = s"dynamicpruning#${exprId.id} $conditionString"
 
   override lazy val canonicalized: DynamicPruning = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 34d1d8f..cb710ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, Project}
-import org.apache.spark.sql.catalyst.trees.TreePattern.{IN, INSET, TreePattern}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{IN, IN_SUBQUERY, INSET, TreePattern}
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -342,6 +342,7 @@ case class InSubquery(values: Seq[Expression], query: ListQuery)
     values.head
   }
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(IN_SUBQUERY)
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (values.length != query.childOutputs.length) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index 2bedf84..ac939bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -22,6 +22,8 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, LIST_SUBQUERY,
+  PLAN_EXPRESSION, SCALAR_SUBQUERY, TreePattern}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.BitSet
 
@@ -38,6 +40,11 @@ abstract class PlanExpression[T <: QueryPlan[_]] extends Expression {
     bits
   }
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(PLAN_EXPRESSION) ++ nodePatternsInternal
+
+  // Subclasses can override this function to provide more TreePatterns.
+  def nodePatternsInternal(): Seq[TreePattern] = Seq()
+
   /**  The id of the subquery expression. */
   def exprId: ExprId
 
@@ -247,6 +254,8 @@ case class ScalarSubquery(
 
   override protected def withNewChildrenInternal(
     newChildren: IndexedSeq[Expression]): ScalarSubquery = copy(children = newChildren)
+
+  final override def nodePatternsInternal: Seq[TreePattern] = Seq(SCALAR_SUBQUERY)
 }
 
 object ScalarSubquery {
@@ -295,6 +304,8 @@ case class ListQuery(
 
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ListQuery =
     copy(children = newChildren)
+
+  final override def nodePatternsInternal: Seq[TreePattern] = Seq(LIST_SUBQUERY)
 }
 
 /**
@@ -340,4 +351,6 @@ case class Exists(
 
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Exists =
     copy(children = newChildren)
+
+  final override def nodePatternsInternal: Seq[TreePattern] = Seq(EXISTS_SUBQUERY)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 5343fce..09e7cff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -283,7 +284,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
         case other => other
       }
     }
-    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
+      _.containsPattern(PLAN_EXPRESSION), ruleId) {
       case s: SubqueryExpression =>
         val Subquery(newPlan, _) = Optimizer.this.execute(Subquery.fromExpression(s))
         // At this point we have an optimized subquery plan that we are going to attach
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 9381796..fa87894 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, FILTER, IN_SUBQUERY,
+  LIST_SUBQUERY, SCALAR_SUBQUERY}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -94,7 +96,8 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
     }
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    t => t.containsAnyPattern(EXISTS_SUBQUERY, LIST_SUBQUERY) && t.containsPattern(FILTER)) {
     case Filter(condition, child)
       if SubqueryExpression.hasInOrCorrelatedExistsSubquery(condition) =>
       val (withSubquery, withoutSubquery) =
@@ -164,7 +167,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
       plan: LogicalPlan): (Option[Expression], LogicalPlan) = {
     var newPlan = plan
     val newExprs = exprs.map { e =>
-      e transformDown {
+      e.transformDownWithPruning(_.containsAnyPattern(EXISTS_SUBQUERY, IN_SUBQUERY)) {
         case Exists(sub, conditions, _) =>
           val exists = AttributeReference("exists", BooleanType, nullable = false)()
           newPlan =
@@ -303,7 +306,8 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
       }
     }
 
-    plan transformExpressions {
+    plan.transformExpressionsWithPruning(_.containsAnyPattern(
+      SCALAR_SUBQUERY, EXISTS_SUBQUERY, LIST_SUBQUERY)) {
       case ScalarSubquery(sub, children, exprId) if children.nonEmpty =>
         val (newPlan, newCond) = decorrelate(sub, outerPlans)
         ScalarSubquery(newPlan, getJoinCondition(newCond, children), exprId)
@@ -319,7 +323,8 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
   /**
    * Pull up the correlated predicates and rewrite all subqueries in an operator tree..
    */
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsAnyPattern(SCALAR_SUBQUERY, EXISTS_SUBQUERY, LIST_SUBQUERY)) {
     case f @ Filter(_, a: Aggregate) =>
       rewriteSubQueries(f, Seq(a, a.child))
     // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries.
@@ -341,7 +346,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
   private def extractCorrelatedScalarSubqueries[E <: Expression](
       expression: E,
       subqueries: ArrayBuffer[ScalarSubquery]): E = {
-    val newExpression = expression transform {
+    val newExpression = expression.transformWithPruning(_.containsPattern(SCALAR_SUBQUERY)) {
       case s: ScalarSubquery if s.children.nonEmpty =>
         subqueries += s
         s.plan.output.head
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 49e3e3c..bb999ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.catalyst.trees.TreePattern.{
-  INNER_LIKE_JOIN, JOIN, LEFT_SEMI_OR_ANTI_JOIN, NATURAL_LIKE_JOIN, OUTER_JOIN, TreePattern
+  FILTER, INNER_LIKE_JOIN, JOIN, LEFT_SEMI_OR_ANTI_JOIN, NATURAL_LIKE_JOIN, OUTER_JOIN, TreePattern
 }
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -166,6 +166,8 @@ case class Filter(condition: Expression, child: LogicalPlan)
 
   override def maxRows: Option[Long] = child.maxRows
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(FILTER)
+
   override protected lazy val validConstraints: ExpressionSet = {
     val predicates = splitConjunctivePredicates(condition)
       .filterNot(SubqueryExpression.hasCorrelatedSubquery)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index d745f50..884e259 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -42,12 +42,15 @@ object RuleIdCollection {
       // Catalyst Analyzer rules
       "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveNaturalAndUsingJoin" ::
       "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveRandomSeed" ::
+      "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveSubquery" ::
       "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowFrame" ::
       "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowOrder" ::
+      "org.apache.spark.sql.catalyst.analysis.UpdateOuterReferences" ::
       // Catalyst Optimizer rules
       "org.apache.spark.sql.catalyst.optimizer.CostBasedJoinReorder" ::
       "org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin" ::
       "org.apache.spark.sql.catalyst.optimizer.OptimizeIn" ::
+      "org.apache.spark.sql.catalyst.optimizer.Optimizer$OptimizeSubqueries" ::
       "org.apache.spark.sql.catalyst.optimizer.PushDownLeftSemiAntiJoin" ::
       "org.apache.spark.sql.catalyst.optimizer.PushExtraPredicateThroughJoin" ::
       "org.apache.spark.sql.catalyst.optimizer.PushLeftSemiLeftAntiThroughJoin" ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 3dc1aff..7d725fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -24,15 +24,22 @@ object TreePattern extends Enumeration  {
   // Enum Ids start from 0.
   // Expression patterns (alphabetically ordered)
   val ATTRIBUTE_REFERENCE = Value(0)
+  val DYNAMIC_PRUNING_SUBQUERY: Value = Value
+  val EXISTS_SUBQUERY = Value
   val EXPRESSION_WITH_RANDOM_SEED = Value
   val IN: Value = Value
+  val IN_SUBQUERY: Value = Value
   val INSET: Value = Value
+  val LIST_SUBQUERY: Value = Value
   val LITERAL: Value = Value
   val NULL_LITERAL: Value = Value
+  val PLAN_EXPRESSION: Value = Value
+  val SCALAR_SUBQUERY: Value = Value
   val TRUE_OR_FALSE_LITERAL: Value = Value
   val WINDOW_EXPRESSION: Value = Value
 
   // Logical plan patterns (alphabetically ordered)
+  val FILTER: Value = Value
   val INNER_LIKE_JOIN: Value = Value
   val JOIN: Value = Value
   val LEFT_SEMI_OR_ANTI_JOIN: Value = Value
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
index d98b7c2..1065519 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{ListQuery, SubqueryExpression}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.{DYNAMIC_PRUNING_SUBQUERY, IN_SUBQUERY, SCALAR_SUBQUERY}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
 import org.apache.spark.sql.execution.datasources.v2.V2CommandExec
@@ -113,6 +114,9 @@ case class InsertAdaptiveSparkPlan(
    */
   private def buildSubqueryMap(plan: SparkPlan): Map[Long, BaseSubqueryExec] = {
     val subqueryMap = mutable.HashMap.empty[Long, BaseSubqueryExec]
+    if (!plan.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) {
+      return subqueryMap.toMap
+    }
     plan.foreach(_.expressions.foreach(_.foreach {
       case expressions.ScalarSubquery(p, _, exprId)
           if !subqueryMap.contains(exprId.id) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
index 13ff236..a2e4397 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.adaptive
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, DynamicPruningExpression, ListQuery, Literal}
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.{DYNAMIC_PRUNING_SUBQUERY, IN_SUBQUERY,
+  SCALAR_SUBQUERY}
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution.{BaseSubqueryExec, InSubqueryExec, SparkPlan}
 
@@ -27,7 +29,8 @@ case class PlanAdaptiveSubqueries(
     subqueryMap: Map[Long, BaseSubqueryExec]) extends Rule[SparkPlan] {
 
   def apply(plan: SparkPlan): SparkPlan = {
-    plan.transformAllExpressions {
+    plan.transformAllExpressionsWithPruning(
+      _.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) {
       case expressions.ScalarSubquery(_, _, exprId) =>
         execution.ScalarSubquery(subqueryMap(exprId.id), exprId)
       case expressions.InSubquery(values, ListQuery(_, _, exprId, _)) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
index d3bc4ae..9a05e39 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
 import org.apache.spark.sql.catalyst.plans.logical.Aggregate
 import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.DYNAMIC_PRUNING_SUBQUERY
 import org.apache.spark.sql.execution.{InSubqueryExec, QueryExecution, SparkPlan, SubqueryBroadcastExec}
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
 import org.apache.spark.sql.execution.joins._
@@ -49,7 +50,7 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession)
       return plan
     }
 
-    plan transformAllExpressions {
+    plan.transformAllExpressionsWithPruning(_.containsPattern(DYNAMIC_PRUNING_SUBQUERY)) {
       case DynamicPruningSubquery(
           value, buildPlan, buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId) =>
         val sparkPlan = QueryExecution.createSparkPlan(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 15b8501..f96e9ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression,
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.{LeafLike, UnaryLike}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{IN_SUBQUERY, SCALAR_SUBQUERY}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{BooleanType, DataType, StructType}
 
@@ -176,7 +177,7 @@ case class InSubqueryExec(
  */
 case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
   def apply(plan: SparkPlan): SparkPlan = {
-    plan.transformAllExpressions {
+    plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY)) {
       case subquery: expressions.ScalarSubquery =>
         val executedPlan = QueryExecution.prepareExecutedPlan(sparkSession, subquery.plan)
         ScalarSubquery(

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org