You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ge...@apache.org on 2021/04/23 08:34:34 UTC
[spark] branch master updated: [SPARK-35078][SQL] Add tree
traversal pruning in expression rules
This is an automated email from the ASF dual-hosted git repository.
gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9af338c [SPARK-35078][SQL] Add tree traversal pruning in expression rules
9af338c is described below
commit 9af338cd685bce26abbc2dd4d077bde5068157b1
Author: Yingyi Bu <yi...@databricks.com>
AuthorDate: Fri Apr 23 16:33:58 2021 +0800
[SPARK-35078][SQL] Add tree traversal pruning in expression rules
### What changes were proposed in this pull request?
Added the following TreePattern enums:
- AND_OR
- BINARY_ARITHMETIC
- BINARY_COMPARISON
- CASE_WHEN
- CAST
- CONCAT
- COUNT
- IF
- LIKE_FAMLIY (sic — misspelling of LIKE_FAMILY; the identifier is spelled this way in the code introduced by this commit)
- NOT
- NULL_CHECK
- UNARY_POSITIVE
- UPPER_OR_LOWER
Used them in the following rules:
- ConstantPropagation
- ReorderAssociativeOperator
- BooleanSimplification
- SimplifyBinaryComparison
- SimplifyCaseConversionExpressions
- SimplifyConditionals
- PushFoldableIntoBranches
- LikeSimplification
- NullPropagation
- SimplifyCasts
- RemoveDispensableExpressions
- CombineConcats
### Why are the changes needed?
Reduce the number of tree traversals and hence improve the query compilation latency.
### How was this patch tested?
Existing tests.
Closes #32280 from sigmod/expression.
Authored-by: Yingyi Bu <yi...@databricks.com>
Signed-off-by: Gengliang Wang <lt...@gmail.com>
---
.../spark/sql/catalyst/expressions/Cast.scala | 3 ++
.../sql/catalyst/expressions/aggregate/Count.scala | 3 ++
.../sql/catalyst/expressions/arithmetic.scala | 6 +++
.../expressions/collectionOperations.scala | 3 ++
.../expressions/conditionalExpressions.scala | 5 ++
.../sql/catalyst/expressions/nullExpressions.scala | 5 ++
.../sql/catalyst/expressions/objects/objects.scala | 3 ++
.../sql/catalyst/expressions/predicates.scala | 14 +++--
.../catalyst/expressions/regexpExpressions.scala | 5 ++
.../catalyst/expressions/stringExpressions.scala | 5 ++
.../spark/sql/catalyst/optimizer/expressions.scala | 62 ++++++++++++++--------
.../plans/logical/basicLogicalOperators.scala | 4 +-
.../sql/catalyst/rules/RuleIdCollection.scala | 15 +++++-
.../spark/sql/catalyst/trees/TreePatterns.scala | 17 +++++-
14 files changed, 120 insertions(+), 30 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 5d799c7..30317c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast.{forceNullable, resolvable
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
+import org.apache.spark.sql.catalyst.trees.TreePattern.{CAST, TreePattern}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
@@ -1800,6 +1801,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))
+ final override val nodePatterns: Seq[TreePattern] = Seq(CAST)
+
override protected val ansiEnabled: Boolean = SQLConf.get.ansiEnabled
override def canCast(from: DataType, to: DataType): Boolean = if (ansiEnabled) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index 1d13155..dfdd828 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{COUNT, TreePattern}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -48,6 +49,8 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
override def nullable: Boolean = false
+ final override val nodePatterns: Seq[TreePattern] = Seq(COUNT)
+
// Return data type.
override def dataType: DataType = LongType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 954a4b9..10b4a7b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -21,6 +21,8 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_ARITHMETIC, TreePattern,
+ UNARY_POSITIVE}
import org.apache.spark.sql.catalyst.util.{IntervalUtils, TypeUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
@@ -128,6 +130,8 @@ case class UnaryPositive(child: Expression)
override def dataType: DataType = child.dataType
+ final override val nodePatterns: Seq[TreePattern] = Seq(UNARY_POSITIVE)
+
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
defineCodeGen(ctx, ev, c => c)
@@ -199,6 +203,8 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
override def dataType: DataType = left.dataType
+ final override val nodePatterns: Seq[TreePattern] = Seq(BINARY_ARITHMETIC)
+
override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess
/** Name of the function for this expression on a [[Decimal]] type. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 125e796..57bfbcc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, Un
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{CONCAT, TreePattern}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
@@ -2172,6 +2173,8 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType)
+ final override val nodePatterns: Seq[TreePattern] = Seq(CONCAT)
+
override def checkInputDataTypes(): TypeCheckResult = {
if (children.isEmpty) {
TypeCheckResult.TypeCheckSuccess
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index e708d56..3e356f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TernaryLike
+import org.apache.spark.sql.catalyst.trees.TreePattern.{CASE_WHEN, IF, TreePattern}
import org.apache.spark.sql.types._
// scalastyle:off line.size.limit
@@ -48,6 +49,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
override def third: Expression = falseValue
override def nullable: Boolean = trueValue.nullable || falseValue.nullable
+ final override val nodePatterns : Seq[TreePattern] = Seq(IF)
+
override def checkInputDataTypes(): TypeCheckResult = {
if (predicate.dataType != BooleanType) {
TypeCheckResult.TypeCheckFailure(
@@ -139,6 +142,8 @@ case class CaseWhen(
override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
+ final override val nodePatterns : Seq[TreePattern] = Seq(CASE_WHEN)
+
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
super.legacyWithNewChildren(newChildren)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 2c2df6b..d4a02c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{NULL_CHECK, TreePattern}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@@ -345,6 +346,8 @@ case class NaNvl(left: Expression, right: Expression)
case class IsNull(child: Expression) extends UnaryExpression with Predicate {
override def nullable: Boolean = false
+ final override val nodePatterns: Seq[TreePattern] = Seq(NULL_CHECK)
+
override def eval(input: InternalRow): Any = {
child.eval(input) == null
}
@@ -375,6 +378,8 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
override def nullable: Boolean = false
+ final override val nodePatterns: Seq[TreePattern] = Seq(NULL_CHECK)
+
override def eval(input: InternalRow): Any = {
child.eval(input) != null
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 5ae0cef..a17ac20 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TernaryLike
+import org.apache.spark.sql.catalyst.trees.TreePattern.{NULL_CHECK, TreePattern}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._
@@ -1705,6 +1706,8 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil)
override def foldable: Boolean = false
override def nullable: Boolean = false
+ final override val nodePatterns: Seq[TreePattern] = Seq(NULL_CHECK)
+
override def flatArguments: Iterator[Any] = Iterator(child)
private val errMsg = "Null value appeared in non-nullable field:" +
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index cb710ad..4885f77 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, Project}
-import org.apache.spark.sql.catalyst.trees.TreePattern.{IN, IN_SUBQUERY, INSET, TreePattern}
+import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -309,6 +309,8 @@ case class Not(child: Expression)
override def inputTypes: Seq[DataType] = Seq(BooleanType)
+ final override val nodePatterns: Seq[TreePattern] = Seq(NOT)
+
// +---------+-----------+
// | CHILD | NOT CHILD |
// +---------+-----------+
@@ -435,7 +437,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
override def nullable: Boolean = children.exists(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)
- override val nodePatterns: Seq[TreePattern] = Seq(IN)
+ final override val nodePatterns: Seq[TreePattern] = Seq(IN)
override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}"
@@ -548,7 +550,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
override def nullable: Boolean = child.nullable || hasNull
- override val nodePatterns: Seq[TreePattern] = Seq(INSET)
+ final override val nodePatterns: Seq[TreePattern] = Seq(INSET)
protected override def nullSafeEval(value: Any): Any = {
if (set.contains(value)) {
@@ -666,6 +668,8 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
override def sqlOperator: String = "AND"
+ final override val nodePatterns: Seq[TreePattern] = Seq(AND_OR)
+
// +---------+---------+---------+---------+
// | AND | TRUE | FALSE | UNKNOWN |
// +---------+---------+---------+---------+
@@ -752,6 +756,8 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
override def sqlOperator: String = "OR"
+ final override val nodePatterns: Seq[TreePattern] = Seq(AND_OR)
+
// +---------+---------+---------+---------+
// | OR | TRUE | FALSE | UNKNOWN |
// +---------+---------+---------+---------+
@@ -823,6 +829,8 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
// finitely enumerable. The allowable types are checked below by checkInputDataTypes.
override def inputType: AbstractDataType = AnyDataType
+ final override val nodePatterns: Seq[TreePattern] = Seq(BINARY_COMPARISON)
+
override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckSuccess =>
TypeUtils.checkForOrderingExpr(left.dataType, this.getClass.getSimpleName)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index 13d00fa..57d7d76 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, TreePattern}
import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._
@@ -129,6 +130,8 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches()
+ final override val nodePatterns: Seq[TreePattern] = Seq(LIKE_FAMLIY)
+
override def toString: String = escapeChar match {
case '\\' => s"$left LIKE $right"
case c => s"$left LIKE $right ESCAPE '$c'"
@@ -198,6 +201,8 @@ sealed abstract class MultiLikeBase
override def nullable: Boolean = true
+ final override val nodePatterns: Seq[TreePattern] = Seq(LIKE_FAMLIY)
+
protected lazy val hasNull: Boolean = patterns.contains(null)
protected lazy val cache = patterns.filterNot(_ == null)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 3d5f812..5956c3e 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
@@ -406,6 +407,8 @@ case class Upper(child: Expression)
override def convert(v: UTF8String): UTF8String = v.toUpperCase
// scalastyle:on caselocale
+ final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER)
+
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, c => s"($c).toUpperCase()")
}
@@ -432,6 +435,8 @@ case class Lower(child: Expression)
override def convert(v: UTF8String): UTF8String = v.toLowerCase
// scalastyle:on caselocale
+ final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER)
+
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, c => s"($c).toLowerCase()")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 49e2412..372acc7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -28,7 +28,8 @@ import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.catalyst.trees.TreePattern.IN
+import org.apache.spark.sql.catalyst.trees.AlwaysProcess
+import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -50,8 +51,9 @@ object ConstantFolding extends Rule[LogicalPlan] {
case _ => false
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case q: LogicalPlan => q transformExpressionsDown {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(AlwaysProcess.fn, ruleId) {
+ case q: LogicalPlan => q.transformExpressionsDownWithPruning(
+ AlwaysProcess.fn, ruleId) {
// Skip redundant folding of literals. This rule is technically not necessary. Placing this
// here avoids running the next rule for Literal values, which would create a new Literal
// object and running eval unnecessarily.
@@ -83,7 +85,8 @@ object ConstantFolding extends Rule[LogicalPlan] {
* in the AND node.
*/
object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+ _.containsAllPatterns(LITERAL, FILTER), ruleId) {
case f: Filter =>
val (newCondition, _) = traverse(f.condition, replaceChildren = true, nullIsFalse = true)
if (newCondition.isDefined) {
@@ -210,14 +213,15 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] {
case _ => ExpressionSet(Seq.empty)
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+ _.containsPattern(BINARY_ARITHMETIC), ruleId) {
case q: LogicalPlan =>
// We have to respect aggregate expressions which exists in grouping expressions when plan
// is an Aggregate operator, otherwise the optimized expression could not be derived from
// grouping expressions.
// TODO: do not reorder consecutive `Add`s or `Multiply`s with different `failOnError` flags
val groupingExpressionSet = collectGroupingExpressions(q)
- q transformExpressionsDown {
+ q.transformExpressionsDownWithPruning(_.containsPattern(BINARY_ARITHMETIC)) {
case a @ Add(_, _, f) if a.deterministic && a.dataType.isInstanceOf[IntegralType] =>
val (foldables, others) = flattenAdd(a, groupingExpressionSet).partition(_.foldable)
if (foldables.size > 1) {
@@ -286,8 +290,10 @@ object OptimizeIn extends Rule[LogicalPlan] {
* 4. Removes `Not` operator.
*/
object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case q: LogicalPlan => q transformExpressionsUp {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+ _.containsAnyPattern(AND_OR, NOT), ruleId) {
+ case q: LogicalPlan => q.transformExpressionsUpWithPruning(
+ _.containsAnyPattern(AND_OR, NOT), ruleId) {
case TrueLiteral And e => e
case e And TrueLiteral => e
case FalseLiteral Or e => e
@@ -460,7 +466,8 @@ object SimplifyBinaryComparison
}
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+ _.containsPattern(BINARY_COMPARISON), ruleId) {
case l: LogicalPlan =>
lazy val notNullExpressions = ExpressionSet(l match {
case Filter(fc, _) =>
@@ -470,7 +477,7 @@ object SimplifyBinaryComparison
case _ => Seq.empty
})
- l transformExpressionsUp {
+ l.transformExpressionsUpWithPruning(_.containsPattern(BINARY_COMPARISON)) {
// True with equality
case a EqualNullSafe b if a.semanticEquals(b) => TrueLiteral
case a EqualTo b if canSimplifyComparison(a, b, notNullExpressions) => TrueLiteral
@@ -496,7 +503,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
case _ => false
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+ _.containsAnyPattern(IF, CASE_WHEN), ruleId) {
case q: LogicalPlan => q transformExpressionsUp {
case If(TrueLiteral, trueValue, _) => trueValue
case If(FalseLiteral, _, falseValue) => falseValue
@@ -601,8 +609,10 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
case _ => false
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case q: LogicalPlan => q transformExpressionsUp {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+ _.containsAnyPattern(CASE_WHEN, IF), ruleId) {
+ case q: LogicalPlan => q.transformExpressionsUpWithPruning(
+ _.containsAnyPattern(CASE_WHEN, IF), ruleId) {
case u @ UnaryExpression(i @ If(_, trueValue, falseValue))
if supportedUnaryExpression(u) && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
i.copy(
@@ -713,7 +723,8 @@ object LikeSimplification extends Rule[LogicalPlan] {
}
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
+ _.containsPattern(LIKE_FAMLIY), ruleId) {
case l @ Like(input, Literal(pattern, StringType), escapeChar) =>
if (pattern == null) {
// If pattern is null, return null value directly, since "col like null" == null.
@@ -740,8 +751,12 @@ object NullPropagation extends Rule[LogicalPlan] {
case _ => false
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case q: LogicalPlan => q transformExpressionsUp {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+ t => t.containsAnyPattern(NULL_CHECK, NULL_LITERAL, COUNT)
+ || t.containsAllPatterns(WINDOW_EXPRESSION, CAST, LITERAL), ruleId) {
+ case q: LogicalPlan => q.transformExpressionsUpWithPruning(
+ t => t.containsAnyPattern(NULL_CHECK, NULL_LITERAL, COUNT)
+ || t.containsAllPatterns(WINDOW_EXPRESSION, CAST, LITERAL), ruleId) {
case e @ WindowExpression(Cast(Literal(0L, _), _, _), _) =>
Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone))
case e @ AggregateExpression(Count(exprs), _, _, _, _) if exprs.forall(isNullLiteral) =>
@@ -917,7 +932,8 @@ object FoldablePropagation extends Rule[LogicalPlan] {
* Removes [[Cast Casts]] that are unnecessary because the input is already the correct type.
*/
object SimplifyCasts extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
+ _.containsPattern(CAST), ruleId) {
case Cast(e, dataType, _) if e.dataType == dataType => e
case c @ Cast(e, dataType, _) => (e.dataType, dataType) match {
case (ArrayType(from, false), ArrayType(to, true)) if from == to => e
@@ -933,7 +949,8 @@ object SimplifyCasts extends Rule[LogicalPlan] {
* Removes nodes that are not necessary.
*/
object RemoveDispensableExpressions extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
+ _.containsPattern(UNARY_POSITIVE), ruleId) {
case UnaryPositive(child) => child
}
}
@@ -944,8 +961,10 @@ object RemoveDispensableExpressions extends Rule[LogicalPlan] {
* the inner conversion is overwritten by the outer one.
*/
object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case q: LogicalPlan => q transformExpressionsUp {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+ _.containsPattern(UPPER_OR_LOWER), ruleId) {
+ case q: LogicalPlan => q.transformExpressionsUpWithPruning(
+ _.containsPattern(UPPER_OR_LOWER), ruleId) {
case Upper(Upper(child)) => Upper(child)
case Upper(Lower(child)) => Upper(child)
case Lower(Upper(child)) => Lower(child)
@@ -986,7 +1005,8 @@ object CombineConcats extends Rule[LogicalPlan] {
case _ => false
}
- def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
+ _.containsPattern(CONCAT), ruleId) {
case concat: Concat if hasNestedConcats(concat) =>
flattenConcats(concat)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index c775bcc..0f5bc7e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -26,9 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
-import org.apache.spark.sql.catalyst.trees.TreePattern.{
- FILTER, INNER_LIKE_JOIN, JOIN, LEFT_SEMI_OR_ANTI_JOIN, NATURAL_LIKE_JOIN, OUTER_JOIN, TreePattern
-}
+import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index 884e259..1c997d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -47,15 +47,28 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowOrder" ::
"org.apache.spark.sql.catalyst.analysis.UpdateOuterReferences" ::
// Catalyst Optimizer rules
+ "org.apache.spark.sql.catalyst.optimizer.BooleanSimplification" ::
+ "org.apache.spark.sql.catalyst.optimizer.CombineConcats" ::
+ "org.apache.spark.sql.catalyst.optimizer.ConstantFolding" ::
+ "org.apache.spark.sql.catalyst.optimizer.ConstantPropagation" ::
"org.apache.spark.sql.catalyst.optimizer.CostBasedJoinReorder" ::
"org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin" ::
+ "org.apache.spark.sql.catalyst.optimizer.LikeSimplification" ::
+ "org.apache.spark.sql.catalyst.optimizer.NullPropagation" ::
"org.apache.spark.sql.catalyst.optimizer.OptimizeIn" ::
"org.apache.spark.sql.catalyst.optimizer.Optimizer$OptimizeSubqueries" ::
"org.apache.spark.sql.catalyst.optimizer.PushDownLeftSemiAntiJoin" ::
"org.apache.spark.sql.catalyst.optimizer.PushExtraPredicateThroughJoin" ::
+ "org.apache.spark.sql.catalyst.optimizer.PushFoldableIntoBranches" ::
"org.apache.spark.sql.catalyst.optimizer.PushLeftSemiLeftAntiThroughJoin" ::
+ "org.apache.spark.sql.catalyst.optimizer.RemoveDispensableExpressions" ::
+ "org.apache.spark.sql.catalyst.optimizer.ReorderAssociativeOperator" ::
"org.apache.spark.sql.catalyst.optimizer.ReorderJoin" ::
- "org.apache.spark.sql.catalyst.optimizer.ReplaceNullWithFalseInPredicate" :: Nil
+ "org.apache.spark.sql.catalyst.optimizer.ReplaceNullWithFalseInPredicate" ::
+ "org.apache.spark.sql.catalyst.optimizer.SimplifyBinaryComparison" ::
+ "org.apache.spark.sql.catalyst.optimizer.SimplifyCaseConversionExpressions" ::
+ "org.apache.spark.sql.catalyst.optimizer.SimplifyCasts" ::
+ "org.apache.spark.sql.catalyst.optimizer.SimplifyConditionals" :: Nil
}
// Maps rule names to ids. Rule ids are continuous natural numbers starting from 0.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 7d725fa..faf736d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -23,20 +23,33 @@ object TreePattern extends Enumeration {
// Enum Ids start from 0.
// Expression patterns (alphabetically ordered)
- val ATTRIBUTE_REFERENCE = Value(0)
+ val AND_OR: Value = Value(0)
+ val ATTRIBUTE_REFERENCE: Value = Value
+ val BINARY_ARITHMETIC: Value = Value
+ val BINARY_COMPARISON: Value = Value
+ val CASE_WHEN: Value = Value
+ val CAST: Value = Value
+ val CONCAT: Value = Value
+ val COUNT: Value = Value
val DYNAMIC_PRUNING_SUBQUERY: Value = Value
val EXISTS_SUBQUERY = Value
- val EXPRESSION_WITH_RANDOM_SEED = Value
+ val EXPRESSION_WITH_RANDOM_SEED: Value = Value
+ val IF: Value = Value
val IN: Value = Value
val IN_SUBQUERY: Value = Value
val INSET: Value = Value
+ val LIKE_FAMLIY: Value = Value
val LIST_SUBQUERY: Value = Value
val LITERAL: Value = Value
+ val NOT: Value = Value
+ val NULL_CHECK: Value = Value
val NULL_LITERAL: Value = Value
val PLAN_EXPRESSION: Value = Value
val SCALAR_SUBQUERY: Value = Value
val TRUE_OR_FALSE_LITERAL: Value = Value
val WINDOW_EXPRESSION: Value = Value
+ val UNARY_POSITIVE: Value = Value
+ val UPPER_OR_LOWER: Value = Value
// Logical plan patterns (alphabetically ordered)
val FILTER: Value = Value
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org