Posted to commits@spark.apache.org by ge...@apache.org on 2021/04/23 08:34:34 UTC

[spark] branch master updated: [SPARK-35078][SQL] Add tree traversal pruning in expression rules

This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9af338c  [SPARK-35078][SQL] Add tree traversal pruning in expression rules
9af338c is described below

commit 9af338cd685bce26abbc2dd4d077bde5068157b1
Author: Yingyi Bu <yi...@databricks.com>
AuthorDate: Fri Apr 23 16:33:58 2021 +0800

    [SPARK-35078][SQL] Add tree traversal pruning in expression rules
    
    ### What changes were proposed in this pull request?
    
    Added the following TreePattern enums:
    - AND_OR
    - BINARY_ARITHMETIC
    - BINARY_COMPARISON
    - CASE_WHEN
    - CAST
    - CONCAT
    - COUNT
    - IF
    - LIKE_FAMLIY
    - NOT
    - NULL_CHECK
    - UNARY_POSITIVE
    - UPPER_OR_LOWER
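    
    Each node class declares the pattern it carries. A condensed sketch of the
    declaration side, abridged from this commit's change to `Cast` in the diff
    below (traits and members elided, so not standalone):
    
    ```scala
    import org.apache.spark.sql.catalyst.trees.TreePattern.{CAST, TreePattern}
    
    case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String])
        extends UnaryExpression /* ...other traits elided... */ {
      // The node advertises its pattern once; pruned transforms consult the
      // cached pattern bits of a subtree instead of pattern-matching every node.
      final override val nodePatterns: Seq[TreePattern] = Seq(CAST)
      // ...the rest of Cast is unchanged...
    }
    ```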
    
    Used them in the following rules:
    - ConstantPropagation
    - ReorderAssociativeOperator
    - BooleanSimplification
    - SimplifyBinaryComparison
    - SimplifyCaseConversionExpressions
    - SimplifyConditionals
    - PushFoldableIntoBranches
    - LikeSimplification
    - NullPropagation
    - SimplifyCasts
    - RemoveDispensableExpressions
    - CombineConcats
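    
    On the rule side, each rewrite guards its transform with the patterns it
    can fire on, as in this condensed form of the `SimplifyCasts` change from
    the diff below (it relies on the Catalyst APIs shown there, so it is not
    standalone):
    
    ```scala
    object SimplifyCasts extends Rule[LogicalPlan] {
      // The partial function below is only applied to subtrees whose cached
      // pattern bits contain CAST; everything else is skipped wholesale.
      def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
        _.containsPattern(CAST), ruleId) {
        case Cast(e, dataType, _) if e.dataType == dataType => e
      }
    }
    ```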
    
    ### Why are the changes needed?
    
    Reduce the number of tree traversals and hence reduce query compilation latency.
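    
    A self-contained toy model of why this check is cheap (plain Scala,
    hypothetical names, not Catalyst code): each node caches the union of its
    own pattern bits and its children's, so "does this subtree contain a cast
    anywhere?" is a single BitSet lookup rather than a traversal.
    
    ```scala
    import scala.collection.immutable.BitSet
    
    object PruningToy {
      // A tiny stand-in for the TreePattern enum, encoded as bit positions.
      val CAST = 0
      val AND_OR = 1
    
      sealed trait Node {
        def children: Seq[Node]
        def nodePatterns: BitSet
        // Computed lazily, once per node, from the children's cached bits,
        // so building the bits is O(children), not O(subtree).
        lazy val treePatternBits: BitSet =
          children.foldLeft(nodePatterns)(_ | _.treePatternBits)
        def containsPattern(p: Int): Boolean = treePatternBits(p)
      }
      case class Leaf() extends Node {
        def children: Seq[Node] = Nil
        def nodePatterns: BitSet = BitSet.empty
      }
      case class CastNode(child: Node) extends Node {
        def children: Seq[Node] = Seq(child)
        def nodePatterns: BitSet = BitSet(CAST)
      }
      case class AndNode(left: Node, right: Node) extends Node {
        def children: Seq[Node] = Seq(left, right)
        def nodePatterns: BitSet = BitSet(AND_OR)
      }
    
      def main(args: Array[String]): Unit = {
        val tree = AndNode(CastNode(Leaf()), Leaf())
        // A SimplifyCasts-style rule would descend into `tree` because this
        // check passes; a tree with no CAST bit set is skipped outright.
        println(tree.containsPattern(CAST))   // true
        println(Leaf().containsPattern(CAST)) // false
      }
    }
    ```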
    
    ### How was this patch tested?
    
    Existing tests.
    
    Closes #32280 from sigmod/expression.
    
    Authored-by: Yingyi Bu <yi...@databricks.com>
    Signed-off-by: Gengliang Wang <lt...@gmail.com>
---
 .../spark/sql/catalyst/expressions/Cast.scala      |  3 ++
 .../sql/catalyst/expressions/aggregate/Count.scala |  3 ++
 .../sql/catalyst/expressions/arithmetic.scala      |  6 +++
 .../expressions/collectionOperations.scala         |  3 ++
 .../expressions/conditionalExpressions.scala       |  5 ++
 .../sql/catalyst/expressions/nullExpressions.scala |  5 ++
 .../sql/catalyst/expressions/objects/objects.scala |  3 ++
 .../sql/catalyst/expressions/predicates.scala      | 14 +++--
 .../catalyst/expressions/regexpExpressions.scala   |  5 ++
 .../catalyst/expressions/stringExpressions.scala   |  5 ++
 .../spark/sql/catalyst/optimizer/expressions.scala | 62 ++++++++++++++--------
 .../plans/logical/basicLogicalOperators.scala      |  4 +-
 .../sql/catalyst/rules/RuleIdCollection.scala      | 15 +++++-
 .../spark/sql/catalyst/trees/TreePatterns.scala    | 17 +++++-
 14 files changed, 120 insertions(+), 30 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 5d799c7..30317c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast.{forceNullable, resolvable
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
+import org.apache.spark.sql.catalyst.trees.TreePattern.{CAST, TreePattern}
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils._
@@ -1800,6 +1801,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
   override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
     copy(timeZoneId = Option(timeZoneId))
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(CAST)
+
   override protected val ansiEnabled: Boolean = SQLConf.get.ansiEnabled
 
   override def canCast(from: DataType, to: DataType): Boolean = if (ansiEnabled) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index 1d13155..dfdd828 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{COUNT, TreePattern}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -48,6 +49,8 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
 
   override def nullable: Boolean = false
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(COUNT)
+
   // Return data type.
   override def dataType: DataType = LongType
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 954a4b9..10b4a7b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -21,6 +21,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_ARITHMETIC, TreePattern,
+  UNARY_POSITIVE}
 import org.apache.spark.sql.catalyst.util.{IntervalUtils, TypeUtils}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
@@ -128,6 +130,8 @@ case class UnaryPositive(child: Expression)
 
   override def dataType: DataType = child.dataType
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(UNARY_POSITIVE)
+
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
     defineCodeGen(ctx, ev, c => c)
 
@@ -199,6 +203,8 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
 
   override def dataType: DataType = left.dataType
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(BINARY_ARITHMETIC)
+
   override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess
 
   /** Name of the function for this expression on a [[Decimal]] type. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 125e796..57bfbcc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, Un
 import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{CONCAT, TreePattern}
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils._
@@ -2172,6 +2173,8 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
 
   private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType)
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(CONCAT)
+
   override def checkInputDataTypes(): TypeCheckResult = {
     if (children.isEmpty) {
       TypeCheckResult.TypeCheckSuccess
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index e708d56..3e356f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.trees.TernaryLike
+import org.apache.spark.sql.catalyst.trees.TreePattern.{CASE_WHEN, IF, TreePattern}
 import org.apache.spark.sql.types._
 
 // scalastyle:off line.size.limit
@@ -48,6 +49,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
   override def third: Expression = falseValue
   override def nullable: Boolean = trueValue.nullable || falseValue.nullable
 
+  final override val nodePatterns : Seq[TreePattern] = Seq(IF)
+
   override def checkInputDataTypes(): TypeCheckResult = {
     if (predicate.dataType != BooleanType) {
       TypeCheckResult.TypeCheckFailure(
@@ -139,6 +142,8 @@ case class CaseWhen(
 
   override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
 
+  final override val nodePatterns : Seq[TreePattern] = Seq(CASE_WHEN)
+
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
     super.legacyWithNewChildren(newChildren)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 2c2df6b..d4a02c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{NULL_CHECK, TreePattern}
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
@@ -345,6 +346,8 @@ case class NaNvl(left: Expression, right: Expression)
 case class IsNull(child: Expression) extends UnaryExpression with Predicate {
   override def nullable: Boolean = false
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(NULL_CHECK)
+
   override def eval(input: InternalRow): Any = {
     child.eval(input) == null
   }
@@ -375,6 +378,8 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
 case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
   override def nullable: Boolean = false
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(NULL_CHECK)
+
   override def eval(input: InternalRow): Any = {
     child.eval(input) != null
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 5ae0cef..a17ac20 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.trees.TernaryLike
+import org.apache.spark.sql.catalyst.trees.TreePattern.{NULL_CHECK, TreePattern}
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
@@ -1705,6 +1706,8 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil)
   override def foldable: Boolean = false
   override def nullable: Boolean = false
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(NULL_CHECK)
+
   override def flatArguments: Iterator[Any] = Iterator(child)
 
   private val errMsg = "Null value appeared in non-nullable field:" +
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index cb710ad..4885f77 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, Project}
-import org.apache.spark.sql.catalyst.trees.TreePattern.{IN, IN_SUBQUERY, INSET, TreePattern}
+import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -309,6 +309,8 @@ case class Not(child: Expression)
 
   override def inputTypes: Seq[DataType] = Seq(BooleanType)
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(NOT)
+
   // +---------+-----------+
   // | CHILD   | NOT CHILD |
   // +---------+-----------+
@@ -435,7 +437,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
   override def nullable: Boolean = children.exists(_.nullable)
   override def foldable: Boolean = children.forall(_.foldable)
 
-  override val nodePatterns: Seq[TreePattern] = Seq(IN)
+  final override val nodePatterns: Seq[TreePattern] = Seq(IN)
 
   override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}"
 
@@ -548,7 +550,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
 
   override def nullable: Boolean = child.nullable || hasNull
 
-  override val nodePatterns: Seq[TreePattern] = Seq(INSET)
+  final override val nodePatterns: Seq[TreePattern] = Seq(INSET)
 
   protected override def nullSafeEval(value: Any): Any = {
     if (set.contains(value)) {
@@ -666,6 +668,8 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
 
   override def sqlOperator: String = "AND"
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(AND_OR)
+
   // +---------+---------+---------+---------+
   // | AND     | TRUE    | FALSE   | UNKNOWN |
   // +---------+---------+---------+---------+
@@ -752,6 +756,8 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
 
   override def sqlOperator: String = "OR"
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(AND_OR)
+
   // +---------+---------+---------+---------+
   // | OR      | TRUE    | FALSE   | UNKNOWN |
   // +---------+---------+---------+---------+
@@ -823,6 +829,8 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
   // finitely enumerable. The allowable types are checked below by checkInputDataTypes.
   override def inputType: AbstractDataType = AnyDataType
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(BINARY_COMPARISON)
+
   override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
     case TypeCheckResult.TypeCheckSuccess =>
       TypeUtils.checkForOrderingExpr(left.dataType, this.getClass.getSimpleName)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index 13d00fa..57d7d76 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, TreePattern}
 import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
@@ -129,6 +130,8 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)
 
   override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches()
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(LIKE_FAMLIY)
+
   override def toString: String = escapeChar match {
     case '\\' => s"$left LIKE $right"
     case c => s"$left LIKE $right ESCAPE '$c'"
@@ -198,6 +201,8 @@ sealed abstract class MultiLikeBase
 
   override def nullable: Boolean = true
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(LIKE_FAMLIY)
+
   protected lazy val hasNull: Boolean = patterns.contains(null)
 
   protected lazy val cache = patterns.filterNot(_ == null)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 3d5f812..5956c3e 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER}
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
@@ -406,6 +407,8 @@ case class Upper(child: Expression)
   override def convert(v: UTF8String): UTF8String = v.toUpperCase
   // scalastyle:on caselocale
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER)
+
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, c => s"($c).toUpperCase()")
   }
@@ -432,6 +435,8 @@ case class Lower(child: Expression)
   override def convert(v: UTF8String): UTF8String = v.toLowerCase
   // scalastyle:on caselocale
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER)
+
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, c => s"($c).toLowerCase()")
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 49e2412..372acc7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -28,7 +28,8 @@ import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.catalyst.trees.TreePattern.IN
+import org.apache.spark.sql.catalyst.trees.AlwaysProcess
+import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -50,8 +51,9 @@ object ConstantFolding extends Rule[LogicalPlan] {
     case _ => false
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case q: LogicalPlan => q transformExpressionsDown {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(AlwaysProcess.fn, ruleId) {
+    case q: LogicalPlan => q.transformExpressionsDownWithPruning(
+      AlwaysProcess.fn, ruleId) {
       // Skip redundant folding of literals. This rule is technically not necessary. Placing this
       // here avoids running the next rule for Literal values, which would create a new Literal
       // object and running eval unnecessarily.
@@ -83,7 +85,8 @@ object ConstantFolding extends Rule[LogicalPlan] {
  *   in the AND node.
  */
 object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsAllPatterns(LITERAL, FILTER), ruleId) {
     case f: Filter =>
       val (newCondition, _) = traverse(f.condition, replaceChildren = true, nullIsFalse = true)
       if (newCondition.isDefined) {
@@ -210,14 +213,15 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] {
     case _ => ExpressionSet(Seq.empty)
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(BINARY_ARITHMETIC), ruleId) {
     case q: LogicalPlan =>
       // We have to respect aggregate expressions which exists in grouping expressions when plan
       // is an Aggregate operator, otherwise the optimized expression could not be derived from
       // grouping expressions.
       // TODO: do not reorder consecutive `Add`s or `Multiply`s with different `failOnError` flags
       val groupingExpressionSet = collectGroupingExpressions(q)
-      q transformExpressionsDown {
+      q.transformExpressionsDownWithPruning(_.containsPattern(BINARY_ARITHMETIC)) {
       case a @ Add(_, _, f) if a.deterministic && a.dataType.isInstanceOf[IntegralType] =>
         val (foldables, others) = flattenAdd(a, groupingExpressionSet).partition(_.foldable)
         if (foldables.size > 1) {
@@ -286,8 +290,10 @@ object OptimizeIn extends Rule[LogicalPlan] {
  * 4. Removes `Not` operator.
  */
 object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case q: LogicalPlan => q transformExpressionsUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsAnyPattern(AND_OR, NOT), ruleId) {
+    case q: LogicalPlan => q.transformExpressionsUpWithPruning(
+      _.containsAnyPattern(AND_OR, NOT), ruleId) {
       case TrueLiteral And e => e
       case e And TrueLiteral => e
       case FalseLiteral Or e => e
@@ -460,7 +466,8 @@ object SimplifyBinaryComparison
     }
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(BINARY_COMPARISON), ruleId) {
     case l: LogicalPlan =>
       lazy val notNullExpressions = ExpressionSet(l match {
         case Filter(fc, _) =>
@@ -470,7 +477,7 @@ object SimplifyBinaryComparison
         case _ => Seq.empty
       })
 
-      l transformExpressionsUp {
+      l.transformExpressionsUpWithPruning(_.containsPattern(BINARY_COMPARISON)) {
         // True with equality
         case a EqualNullSafe b if a.semanticEquals(b) => TrueLiteral
         case a EqualTo b if canSimplifyComparison(a, b, notNullExpressions) => TrueLiteral
@@ -496,7 +503,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
     case _ => false
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsAnyPattern(IF, CASE_WHEN), ruleId) {
     case q: LogicalPlan => q transformExpressionsUp {
       case If(TrueLiteral, trueValue, _) => trueValue
       case If(FalseLiteral, _, falseValue) => falseValue
@@ -601,8 +609,10 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
     case _ => false
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case q: LogicalPlan => q transformExpressionsUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsAnyPattern(CASE_WHEN, IF), ruleId) {
+    case q: LogicalPlan => q.transformExpressionsUpWithPruning(
+      _.containsAnyPattern(CASE_WHEN, IF), ruleId) {
       case u @ UnaryExpression(i @ If(_, trueValue, falseValue))
           if supportedUnaryExpression(u) && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
         i.copy(
@@ -713,7 +723,8 @@ object LikeSimplification extends Rule[LogicalPlan] {
     }
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
+    _.containsPattern(LIKE_FAMLIY), ruleId) {
     case l @ Like(input, Literal(pattern, StringType), escapeChar) =>
       if (pattern == null) {
         // If pattern is null, return null value directly, since "col like null" == null.
@@ -740,8 +751,12 @@ object NullPropagation extends Rule[LogicalPlan] {
     case _ => false
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case q: LogicalPlan => q transformExpressionsUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    t => t.containsAnyPattern(NULL_CHECK, NULL_LITERAL, COUNT)
+      || t.containsAllPatterns(WINDOW_EXPRESSION, CAST, LITERAL), ruleId) {
+    case q: LogicalPlan => q.transformExpressionsUpWithPruning(
+      t => t.containsAnyPattern(NULL_CHECK, NULL_LITERAL, COUNT)
+        || t.containsAllPatterns(WINDOW_EXPRESSION, CAST, LITERAL), ruleId) {
       case e @ WindowExpression(Cast(Literal(0L, _), _, _), _) =>
         Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone))
       case e @ AggregateExpression(Count(exprs), _, _, _, _) if exprs.forall(isNullLiteral) =>
@@ -917,7 +932,8 @@ object FoldablePropagation extends Rule[LogicalPlan] {
  * Removes [[Cast Casts]] that are unnecessary because the input is already the correct type.
  */
 object SimplifyCasts extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
+    _.containsPattern(CAST), ruleId) {
     case Cast(e, dataType, _) if e.dataType == dataType => e
     case c @ Cast(e, dataType, _) => (e.dataType, dataType) match {
       case (ArrayType(from, false), ArrayType(to, true)) if from == to => e
@@ -933,7 +949,8 @@ object SimplifyCasts extends Rule[LogicalPlan] {
  * Removes nodes that are not necessary.
  */
 object RemoveDispensableExpressions extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
+    _.containsPattern(UNARY_POSITIVE), ruleId) {
     case UnaryPositive(child) => child
   }
 }
@@ -944,8 +961,10 @@ object RemoveDispensableExpressions extends Rule[LogicalPlan] {
  * the inner conversion is overwritten by the outer one.
  */
 object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case q: LogicalPlan => q transformExpressionsUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(UPPER_OR_LOWER), ruleId) {
+    case q: LogicalPlan => q.transformExpressionsUpWithPruning(
+      _.containsPattern(UPPER_OR_LOWER), ruleId) {
       case Upper(Upper(child)) => Upper(child)
       case Upper(Lower(child)) => Upper(child)
       case Lower(Upper(child)) => Lower(child)
@@ -986,7 +1005,8 @@ object CombineConcats extends Rule[LogicalPlan] {
     case _ => false
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
+    _.containsPattern(CONCAT), ruleId) {
     case concat: Concat if hasNestedConcats(concat) =>
       flattenConcats(concat)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index c775bcc..0f5bc7e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -26,9 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
-import org.apache.spark.sql.catalyst.trees.TreePattern.{
-  FILTER, INNER_LIKE_JOIN, JOIN, LEFT_SEMI_OR_ANTI_JOIN, NATURAL_LIKE_JOIN, OUTER_JOIN, TreePattern
-}
+import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index 884e259..1c997d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -47,15 +47,28 @@ object RuleIdCollection {
       "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowOrder" ::
       "org.apache.spark.sql.catalyst.analysis.UpdateOuterReferences" ::
       // Catalyst Optimizer rules
+      "org.apache.spark.sql.catalyst.optimizer.BooleanSimplification" ::
+      "org.apache.spark.sql.catalyst.optimizer.CombineConcats" ::
+      "org.apache.spark.sql.catalyst.optimizer.ConstantFolding" ::
+      "org.apache.spark.sql.catalyst.optimizer.ConstantPropagation" ::
       "org.apache.spark.sql.catalyst.optimizer.CostBasedJoinReorder" ::
       "org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin" ::
+      "org.apache.spark.sql.catalyst.optimizer.LikeSimplification" ::
+      "org.apache.spark.sql.catalyst.optimizer.NullPropagation" ::
       "org.apache.spark.sql.catalyst.optimizer.OptimizeIn" ::
       "org.apache.spark.sql.catalyst.optimizer.Optimizer$OptimizeSubqueries" ::
       "org.apache.spark.sql.catalyst.optimizer.PushDownLeftSemiAntiJoin" ::
       "org.apache.spark.sql.catalyst.optimizer.PushExtraPredicateThroughJoin" ::
+      "org.apache.spark.sql.catalyst.optimizer.PushFoldableIntoBranches" ::
       "org.apache.spark.sql.catalyst.optimizer.PushLeftSemiLeftAntiThroughJoin" ::
+      "org.apache.spark.sql.catalyst.optimizer.RemoveDispensableExpressions" ::
+      "org.apache.spark.sql.catalyst.optimizer.ReorderAssociativeOperator" ::
       "org.apache.spark.sql.catalyst.optimizer.ReorderJoin" ::
-      "org.apache.spark.sql.catalyst.optimizer.ReplaceNullWithFalseInPredicate" :: Nil
+      "org.apache.spark.sql.catalyst.optimizer.ReplaceNullWithFalseInPredicate" ::
+      "org.apache.spark.sql.catalyst.optimizer.SimplifyBinaryComparison" ::
+      "org.apache.spark.sql.catalyst.optimizer.SimplifyCaseConversionExpressions" ::
+      "org.apache.spark.sql.catalyst.optimizer.SimplifyCasts" ::
+      "org.apache.spark.sql.catalyst.optimizer.SimplifyConditionals" :: Nil
   }
 
   // Maps rule names to ids. Rule ids are continuous natural numbers starting from 0.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 7d725fa..faf736d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -23,20 +23,33 @@ object TreePattern extends Enumeration  {
 
   // Enum Ids start from 0.
   // Expression patterns (alphabetically ordered)
-  val ATTRIBUTE_REFERENCE = Value(0)
+  val AND_OR: Value = Value(0)
+  val ATTRIBUTE_REFERENCE: Value = Value
+  val BINARY_ARITHMETIC: Value = Value
+  val BINARY_COMPARISON: Value = Value
+  val CASE_WHEN: Value = Value
+  val CAST: Value = Value
+  val CONCAT: Value = Value
+  val COUNT: Value = Value
   val DYNAMIC_PRUNING_SUBQUERY: Value = Value
   val EXISTS_SUBQUERY = Value
-  val EXPRESSION_WITH_RANDOM_SEED = Value
+  val EXPRESSION_WITH_RANDOM_SEED: Value = Value
+  val IF: Value = Value
   val IN: Value = Value
   val IN_SUBQUERY: Value = Value
   val INSET: Value = Value
+  val LIKE_FAMLIY: Value = Value
   val LIST_SUBQUERY: Value = Value
   val LITERAL: Value = Value
+  val NOT: Value = Value
+  val NULL_CHECK: Value = Value
   val NULL_LITERAL: Value = Value
   val PLAN_EXPRESSION: Value = Value
   val SCALAR_SUBQUERY: Value = Value
   val TRUE_OR_FALSE_LITERAL: Value = Value
   val WINDOW_EXPRESSION: Value = Value
+  val UNARY_POSITIVE: Value = Value
+  val UPPER_OR_LOWER: Value = Value
 
   // Logical plan patterns (alphabetically ordered)
   val FILTER: Value = Value
