You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ge...@apache.org on 2021/04/23 04:43:41 UTC
[spark] branch master updated: [SPARK-35075][SQL] Add traversal
pruning for subquery related rules
This is an automated email from the ASF dual-hosted git repository.
gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 47f8687 [SPARK-35075][SQL] Add traversal pruning for subquery related rules
47f8687 is described below
commit 47f86875f73edc3bec56d3610ab46a16ec37091c
Author: Yingyi Bu <yi...@databricks.com>
AuthorDate: Fri Apr 23 12:42:55 2021 +0800
[SPARK-35075][SQL] Add traversal pruning for subquery related rules
### What changes were proposed in this pull request?
Added the following TreePattern enums:
- DYNAMIC_PRUNING_SUBQUERY
- EXISTS_SUBQUERY
- IN_SUBQUERY
- LIST_SUBQUERY
- PLAN_EXPRESSION
- SCALAR_SUBQUERY
- FILTER
Used them in the following rules:
- ResolveSubquery
- UpdateOuterReferences
- OptimizeSubqueries
- RewritePredicateSubquery
- PullupCorrelatedPredicates
- RewriteCorrelatedScalarSubquery (not the rule itself but an internal transform call, the full support is in SPARK-35148)
- InsertAdaptiveSparkPlan
- PlanAdaptiveSubqueries
### Why are the changes needed?
Reduce the number of tree traversals and hence improve the query compilation latency.
### How was this patch tested?
Existing tests.
Closes #32247 from sigmod/subquery.
Authored-by: Yingyi Bu <yi...@databricks.com>
Signed-off-by: Gengliang Wang <lt...@gmail.com>
---
.../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 14 +++++++-------
.../spark/sql/catalyst/expressions/DynamicPruning.scala | 3 +++
.../spark/sql/catalyst/expressions/predicates.scala | 3 ++-
.../apache/spark/sql/catalyst/expressions/subquery.scala | 13 +++++++++++++
.../apache/spark/sql/catalyst/optimizer/Optimizer.scala | 4 +++-
.../apache/spark/sql/catalyst/optimizer/subquery.scala | 15 ++++++++++-----
.../catalyst/plans/logical/basicLogicalOperators.scala | 4 +++-
.../spark/sql/catalyst/rules/RuleIdCollection.scala | 3 +++
.../apache/spark/sql/catalyst/trees/TreePatterns.scala | 7 +++++++
.../sql/execution/adaptive/InsertAdaptiveSparkPlan.scala | 4 ++++
.../sql/execution/adaptive/PlanAdaptiveSubqueries.scala | 5 ++++-
.../dynamicpruning/PlanDynamicPruningFilters.scala | 3 ++-
.../scala/org/apache/spark/sql/execution/subquery.scala | 3 ++-
13 files changed, 63 insertions(+), 18 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index c2c146c..87b8d52 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -39,9 +39,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
-import org.apache.spark.sql.catalyst.trees.TreePattern.{
- EXPRESSION_WITH_RANDOM_SEED, NATURAL_LIKE_JOIN, WINDOW_EXPRESSION
-}
+import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils}
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
@@ -2179,7 +2177,8 @@ class Analyzer(override val catalogManager: CatalogManager)
* outer plan to get evaluated.
*/
private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = {
- plan transformExpressions {
+ plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY,
+ EXISTS_SUBQUERY, IN_SUBQUERY), ruleId) {
case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved =>
resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
case e @ Exists(sub, _, exprId) if !sub.resolved =>
@@ -2196,7 +2195,8 @@ class Analyzer(override val catalogManager: CatalogManager)
/**
* Resolve and rewrite all subqueries in an operator tree.
*/
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
+ _.containsAnyPattern(SCALAR_SUBQUERY, EXISTS_SUBQUERY, IN_SUBQUERY), ruleId) {
// In case of HAVING (a filter after an aggregate) we use both the aggregate and
// its child for resolution.
case f @ Filter(_, a: Aggregate) if f.childrenResolved =>
@@ -3790,9 +3790,9 @@ object UpdateOuterReferences extends Rule[LogicalPlan] {
}
def apply(plan: LogicalPlan): LogicalPlan = {
- plan resolveOperators {
+ plan.resolveOperatorsWithPruning(_.containsAllPatterns(PLAN_EXPRESSION, FILTER), ruleId) {
case f @ Filter(_, a: Aggregate) if f.resolved =>
- f transformExpressions {
+ f.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) {
case s: SubqueryExpression if s.children.nonEmpty =>
// Collect the aliases from output of aggregate.
val outerAliases = a.aggregateExpressions collect { case a: Alias => a }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
index de4b874..1c185dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.TreePattern.{DYNAMIC_PRUNING_SUBQUERY, TreePattern}
import org.apache.spark.sql.catalyst.trees.UnaryLike
trait DynamicPruning extends Predicate
@@ -69,6 +70,8 @@ case class DynamicPruningSubquery(
pruningKey.dataType == buildKeys(broadcastKeyIndex).dataType
}
+ final override def nodePatternsInternal: Seq[TreePattern] = Seq(DYNAMIC_PRUNING_SUBQUERY)
+
override def toString: String = s"dynamicpruning#${exprId.id} $conditionString"
override lazy val canonicalized: DynamicPruning = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 34d1d8f..cb710ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, Project}
-import org.apache.spark.sql.catalyst.trees.TreePattern.{IN, INSET, TreePattern}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{IN, IN_SUBQUERY, INSET, TreePattern}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -342,6 +342,7 @@ case class InSubquery(values: Seq[Expression], query: ListQuery)
values.head
}
+ final override val nodePatterns: Seq[TreePattern] = Seq(IN_SUBQUERY)
override def checkInputDataTypes(): TypeCheckResult = {
if (values.length != query.childOutputs.length) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index 2bedf84..ac939bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -22,6 +22,8 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, LIST_SUBQUERY,
+ PLAN_EXPRESSION, SCALAR_SUBQUERY, TreePattern}
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.BitSet
@@ -38,6 +40,11 @@ abstract class PlanExpression[T <: QueryPlan[_]] extends Expression {
bits
}
+ final override val nodePatterns: Seq[TreePattern] = Seq(PLAN_EXPRESSION) ++ nodePatternsInternal
+
+ // Subclasses can override this function to provide more TreePatterns.
+ def nodePatternsInternal(): Seq[TreePattern] = Seq()
+
/** The id of the subquery expression. */
def exprId: ExprId
@@ -247,6 +254,8 @@ case class ScalarSubquery(
override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): ScalarSubquery = copy(children = newChildren)
+
+ final override def nodePatternsInternal: Seq[TreePattern] = Seq(SCALAR_SUBQUERY)
}
object ScalarSubquery {
@@ -295,6 +304,8 @@ case class ListQuery(
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ListQuery =
copy(children = newChildren)
+
+ final override def nodePatternsInternal: Seq[TreePattern] = Seq(LIST_SUBQUERY)
}
/**
@@ -340,4 +351,6 @@ case class Exists(
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Exists =
copy(children = newChildren)
+
+ final override def nodePatternsInternal: Seq[TreePattern] = Seq(EXISTS_SUBQUERY)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 5343fce..09e7cff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -283,7 +284,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
case other => other
}
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
+ _.containsPattern(PLAN_EXPRESSION), ruleId) {
case s: SubqueryExpression =>
val Subquery(newPlan, _) = Optimizer.this.execute(Subquery.fromExpression(s))
// At this point we have an optimized subquery plan that we are going to attach
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 9381796..fa87894 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, FILTER, IN_SUBQUERY,
+ LIST_SUBQUERY, SCALAR_SUBQUERY}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -94,7 +96,8 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
}
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+ t => t.containsAnyPattern(EXISTS_SUBQUERY, LIST_SUBQUERY) && t.containsPattern(FILTER)) {
case Filter(condition, child)
if SubqueryExpression.hasInOrCorrelatedExistsSubquery(condition) =>
val (withSubquery, withoutSubquery) =
@@ -164,7 +167,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
plan: LogicalPlan): (Option[Expression], LogicalPlan) = {
var newPlan = plan
val newExprs = exprs.map { e =>
- e transformDown {
+ e.transformDownWithPruning(_.containsAnyPattern(EXISTS_SUBQUERY, IN_SUBQUERY)) {
case Exists(sub, conditions, _) =>
val exists = AttributeReference("exists", BooleanType, nullable = false)()
newPlan =
@@ -303,7 +306,8 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
}
}
- plan transformExpressions {
+ plan.transformExpressionsWithPruning(_.containsAnyPattern(
+ SCALAR_SUBQUERY, EXISTS_SUBQUERY, LIST_SUBQUERY)) {
case ScalarSubquery(sub, children, exprId) if children.nonEmpty =>
val (newPlan, newCond) = decorrelate(sub, outerPlans)
ScalarSubquery(newPlan, getJoinCondition(newCond, children), exprId)
@@ -319,7 +323,8 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
/**
* Pull up the correlated predicates and rewrite all subqueries in an operator tree.
*/
- def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+ _.containsAnyPattern(SCALAR_SUBQUERY, EXISTS_SUBQUERY, LIST_SUBQUERY)) {
case f @ Filter(_, a: Aggregate) =>
rewriteSubQueries(f, Seq(a, a.child))
// Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries.
@@ -341,7 +346,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
private def extractCorrelatedScalarSubqueries[E <: Expression](
expression: E,
subqueries: ArrayBuffer[ScalarSubquery]): E = {
- val newExpression = expression transform {
+ val newExpression = expression.transformWithPruning(_.containsPattern(SCALAR_SUBQUERY)) {
case s: ScalarSubquery if s.children.nonEmpty =>
subqueries += s
s.plan.output.head
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 49e3e3c..bb999ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.trees.TreePattern.{
- INNER_LIKE_JOIN, JOIN, LEFT_SEMI_OR_ANTI_JOIN, NATURAL_LIKE_JOIN, OUTER_JOIN, TreePattern
+ FILTER, INNER_LIKE_JOIN, JOIN, LEFT_SEMI_OR_ANTI_JOIN, NATURAL_LIKE_JOIN, OUTER_JOIN, TreePattern
}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -166,6 +166,8 @@ case class Filter(condition: Expression, child: LogicalPlan)
override def maxRows: Option[Long] = child.maxRows
+ final override val nodePatterns: Seq[TreePattern] = Seq(FILTER)
+
override protected lazy val validConstraints: ExpressionSet = {
val predicates = splitConjunctivePredicates(condition)
.filterNot(SubqueryExpression.hasCorrelatedSubquery)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index d745f50..884e259 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -42,12 +42,15 @@ object RuleIdCollection {
// Catalyst Analyzer rules
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveNaturalAndUsingJoin" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveRandomSeed" ::
+ "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveSubquery" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowFrame" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowOrder" ::
+ "org.apache.spark.sql.catalyst.analysis.UpdateOuterReferences" ::
// Catalyst Optimizer rules
"org.apache.spark.sql.catalyst.optimizer.CostBasedJoinReorder" ::
"org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin" ::
"org.apache.spark.sql.catalyst.optimizer.OptimizeIn" ::
+ "org.apache.spark.sql.catalyst.optimizer.Optimizer$OptimizeSubqueries" ::
"org.apache.spark.sql.catalyst.optimizer.PushDownLeftSemiAntiJoin" ::
"org.apache.spark.sql.catalyst.optimizer.PushExtraPredicateThroughJoin" ::
"org.apache.spark.sql.catalyst.optimizer.PushLeftSemiLeftAntiThroughJoin" ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 3dc1aff..7d725fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -24,15 +24,22 @@ object TreePattern extends Enumeration {
// Enum Ids start from 0.
// Expression patterns (alphabetically ordered)
val ATTRIBUTE_REFERENCE = Value(0)
+ val DYNAMIC_PRUNING_SUBQUERY: Value = Value
+ val EXISTS_SUBQUERY = Value
val EXPRESSION_WITH_RANDOM_SEED = Value
val IN: Value = Value
+ val IN_SUBQUERY: Value = Value
val INSET: Value = Value
+ val LIST_SUBQUERY: Value = Value
val LITERAL: Value = Value
val NULL_LITERAL: Value = Value
+ val PLAN_EXPRESSION: Value = Value
+ val SCALAR_SUBQUERY: Value = Value
val TRUE_OR_FALSE_LITERAL: Value = Value
val WINDOW_EXPRESSION: Value = Value
// Logical plan patterns (alphabetically ordered)
+ val FILTER: Value = Value
val INNER_LIKE_JOIN: Value = Value
val JOIN: Value = Value
val LEFT_SEMI_OR_ANTI_JOIN: Value = Value
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
index d98b7c2..1065519 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{ListQuery, SubqueryExpression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.{DYNAMIC_PRUNING_SUBQUERY, IN_SUBQUERY, SCALAR_SUBQUERY}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
import org.apache.spark.sql.execution.datasources.v2.V2CommandExec
@@ -113,6 +114,9 @@ case class InsertAdaptiveSparkPlan(
*/
private def buildSubqueryMap(plan: SparkPlan): Map[Long, BaseSubqueryExec] = {
val subqueryMap = mutable.HashMap.empty[Long, BaseSubqueryExec]
+ if (!plan.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) {
+ return subqueryMap.toMap
+ }
plan.foreach(_.expressions.foreach(_.foreach {
case expressions.ScalarSubquery(p, _, exprId)
if !subqueryMap.contains(exprId.id) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
index 13ff236..a2e4397 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.adaptive
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, DynamicPruningExpression, ListQuery, Literal}
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.{DYNAMIC_PRUNING_SUBQUERY, IN_SUBQUERY,
+ SCALAR_SUBQUERY}
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.{BaseSubqueryExec, InSubqueryExec, SparkPlan}
@@ -27,7 +29,8 @@ case class PlanAdaptiveSubqueries(
subqueryMap: Map[Long, BaseSubqueryExec]) extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = {
- plan.transformAllExpressions {
+ plan.transformAllExpressionsWithPruning(
+ _.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) {
case expressions.ScalarSubquery(_, _, exprId) =>
execution.ScalarSubquery(subqueryMap(exprId.id), exprId)
case expressions.InSubquery(values, ListQuery(_, _, exprId, _)) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
index d3bc4ae..9a05e39 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.DYNAMIC_PRUNING_SUBQUERY
import org.apache.spark.sql.execution.{InSubqueryExec, QueryExecution, SparkPlan, SubqueryBroadcastExec}
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.joins._
@@ -49,7 +50,7 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession)
return plan
}
- plan transformAllExpressions {
+ plan.transformAllExpressionsWithPruning(_.containsPattern(DYNAMIC_PRUNING_SUBQUERY)) {
case DynamicPruningSubquery(
value, buildPlan, buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId) =>
val sparkPlan = QueryExecution.createSparkPlan(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 15b8501..f96e9ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression,
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.{LeafLike, UnaryLike}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{IN_SUBQUERY, SCALAR_SUBQUERY}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType, StructType}
@@ -176,7 +177,7 @@ case class InSubqueryExec(
*/
case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = {
- plan.transformAllExpressions {
+ plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY)) {
case subquery: expressions.ScalarSubquery =>
val executedPlan = QueryExecution.prepareExecutedPlan(sparkSession, subquery.plan)
ScalarSubquery(
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org