You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2016/01/12 03:42:29 UTC

spark git commit: [SPARK-12498][SQL][MINOR] BooleanSimplification simplification

Repository: spark
Updated Branches:
  refs/heads/master 473907adf -> 36d493509


[SPARK-12498][SQL][MINOR] BooleanSimplification simplification

Scala syntax allows binary case classes to be used as infix operators in pattern matching (e.g. `case a And b` instead of `case And(a, b)`). This PR makes use of this syntactic sugar to make `BooleanSimplification` more readable.

Author: Cheng Lian <li...@databricks.com>

Closes #10445 from liancheng/boolean-simplification-simplification.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/36d49350
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/36d49350
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/36d49350

Branch: refs/heads/master
Commit: 36d493509d32d14b54af62f5f65e8fa750e7413d
Parents: 473907a
Author: Cheng Lian <li...@databricks.com>
Authored: Mon Jan 11 18:42:26 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Mon Jan 11 18:42:26 2016 -0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/literals.scala     |   4 +
 .../sql/catalyst/optimizer/Optimizer.scala      | 190 +++++++++----------
 2 files changed, 92 insertions(+), 102 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/36d49350/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 17351ef..e0b0203 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -28,6 +28,10 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types._
 
 object Literal {
+  val TrueLiteral: Literal = Literal(true, BooleanType)
+
+  val FalseLiteral: Literal = Literal(false, BooleanType)
+
   def apply(v: Any): Literal = v match {
     case i: Int => Literal(i, IntegerType)
     case l: Long => Literal(l, LongType)

http://git-wip-us.apache.org/repos/asf/spark/blob/36d49350/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f8121a7..b70bc18 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -21,6 +21,7 @@ import scala.collection.immutable.HashSet
 
 import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries}
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
 import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter}
@@ -519,112 +520,97 @@ object OptimizeIn extends Rule[LogicalPlan] {
 object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsUp {
-      case and @ And(left, right) => (left, right) match {
-        // true && r  =>  r
-        case (Literal(true, BooleanType), r) => r
-        // l && true  =>  l
-        case (l, Literal(true, BooleanType)) => l
-        // false && r  =>  false
-        case (Literal(false, BooleanType), _) => Literal(false)
-        // l && false  =>  false
-        case (_, Literal(false, BooleanType)) => Literal(false)
-        // a && a  =>  a
-        case (l, r) if l fastEquals r => l
-        // a && (not(a) || b) => a && b
-        case (l, Or(l1, r)) if (Not(l) == l1) => And(l, r)
-        case (l, Or(r, l1)) if (Not(l) == l1) => And(l, r)
-        case (Or(l, l1), r) if (l1 == Not(r)) => And(l, r)
-        case (Or(l1, l), r) if (l1 == Not(r)) => And(l, r)
-        // (a || b) && (a || c)  =>  a || (b && c)
-        case _ =>
-          // 1. Split left and right to get the disjunctive predicates,
-          //   i.e. lhs = (a, b), rhs = (a, c)
-          // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
-          // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
-          // 4. Apply the formula, get the optimized predicate: common || (ldiff && rdiff)
-          val lhs = splitDisjunctivePredicates(left)
-          val rhs = splitDisjunctivePredicates(right)
-          val common = lhs.filter(e => rhs.exists(e.semanticEquals(_)))
-          if (common.isEmpty) {
-            // No common factors, return the original predicate
-            and
+      case TrueLiteral And e => e
+      case e And TrueLiteral => e
+      case FalseLiteral Or e => e
+      case e Or FalseLiteral => e
+
+      case FalseLiteral And _ => FalseLiteral
+      case _ And FalseLiteral => FalseLiteral
+      case TrueLiteral Or _ => TrueLiteral
+      case _ Or TrueLiteral => TrueLiteral
+
+      case a And b if a.semanticEquals(b) => a
+      case a Or b if a.semanticEquals(b) => a
+
+      case a And (b Or c) if Not(a).semanticEquals(b) => And(a, c)
+      case a And (b Or c) if Not(a).semanticEquals(c) => And(a, b)
+      case (a Or b) And c if a.semanticEquals(Not(c)) => And(b, c)
+      case (a Or b) And c if b.semanticEquals(Not(c)) => And(a, c)
+
+      case a Or (b And c) if Not(a).semanticEquals(b) => Or(a, c)
+      case a Or (b And c) if Not(a).semanticEquals(c) => Or(a, b)
+      case (a And b) Or c if a.semanticEquals(Not(c)) => Or(b, c)
+      case (a And b) Or c if b.semanticEquals(Not(c)) => Or(a, c)
+
+      // Common factor elimination for conjunction
+      case and @ (left And right) =>
+        // 1. Split left and right to get the disjunctive predicates,
+        //   i.e. lhs = (a, b), rhs = (a, c)
+        // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
+        // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
+        // 4. Apply the formula, get the optimized predicate: common || (ldiff && rdiff)
+        val lhs = splitDisjunctivePredicates(left)
+        val rhs = splitDisjunctivePredicates(right)
+        val common = lhs.filter(e => rhs.exists(e.semanticEquals))
+        if (common.isEmpty) {
+          // No common factors, return the original predicate
+          and
+        } else {
+          val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
+          val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
+          if (ldiff.isEmpty || rdiff.isEmpty) {
+            // (a || b || c || ...) && (a || b) => (a || b)
+            common.reduce(Or)
           } else {
-            val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals(_)))
-            val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals(_)))
-            if (ldiff.isEmpty || rdiff.isEmpty) {
-              // (a || b || c || ...) && (a || b) => (a || b)
-              common.reduce(Or)
-            } else {
-              // (a || b || c || ...) && (a || b || d || ...) =>
-              // ((c || ...) && (d || ...)) || a || b
-              (common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or)
-            }
+            // (a || b || c || ...) && (a || b || d || ...) =>
+            // ((c || ...) && (d || ...)) || a || b
+            (common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or)
           }
-      }  // end of And(left, right)
-
-      case or @ Or(left, right) => (left, right) match {
-        // true || r  =>  true
-        case (Literal(true, BooleanType), _) => Literal(true)
-        // r || true  =>  true
-        case (_, Literal(true, BooleanType)) => Literal(true)
-        // false || r  =>  r
-        case (Literal(false, BooleanType), r) => r
-        // l || false  =>  l
-        case (l, Literal(false, BooleanType)) => l
-        // a || a => a
-        case (l, r) if l fastEquals r => l
-        // (a && b) || (a && c)  =>  a && (b || c)
-        case _ =>
-           // 1. Split left and right to get the conjunctive predicates,
-           //   i.e.  lhs = (a, b), rhs = (a, c)
-           // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
-           // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
-           // 4. Apply the formula, get the optimized predicate: common && (ldiff || rdiff)
-          val lhs = splitConjunctivePredicates(left)
-          val rhs = splitConjunctivePredicates(right)
-          val common = lhs.filter(e => rhs.exists(e.semanticEquals(_)))
-          if (common.isEmpty) {
-            // No common factors, return the original predicate
-            or
+        }
+
+      // Common factor elimination for disjunction
+      case or @ (left Or right) =>
+        // 1. Split left and right to get the conjunctive predicates,
+        //   i.e.  lhs = (a, b), rhs = (a, c)
+        // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
+        // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
+        // 4. Apply the formula, get the optimized predicate: common && (ldiff || rdiff)
+        val lhs = splitConjunctivePredicates(left)
+        val rhs = splitConjunctivePredicates(right)
+        val common = lhs.filter(e => rhs.exists(e.semanticEquals))
+        if (common.isEmpty) {
+          // No common factors, return the original predicate
+          or
+        } else {
+          val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
+          val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
+          if (ldiff.isEmpty || rdiff.isEmpty) {
+            // (a && b) || (a && b && c && ...) => a && b
+            common.reduce(And)
           } else {
-            val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals(_)))
-            val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals(_)))
-            if (ldiff.isEmpty || rdiff.isEmpty) {
-              // (a && b) || (a && b && c && ...) => a && b
-              common.reduce(And)
-            } else {
-              // (a && b && c && ...) || (a && b && d && ...) =>
-              // ((c && ...) || (d && ...)) && a && b
-              (common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And)
-            }
+            // (a && b && c && ...) || (a && b && d && ...) =>
+            // ((c && ...) || (d && ...)) && a && b
+            (common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And)
           }
-      }  // end of Or(left, right)
-
-      case not @ Not(exp) => exp match {
-        // not(true)  =>  false
-        case Literal(true, BooleanType) => Literal(false)
-        // not(false)  =>  true
-        case Literal(false, BooleanType) => Literal(true)
-        // not(l > r)  =>  l <= r
-        case GreaterThan(l, r) => LessThanOrEqual(l, r)
-        // not(l >= r)  =>  l < r
-        case GreaterThanOrEqual(l, r) => LessThan(l, r)
-        // not(l < r)  =>  l >= r
-        case LessThan(l, r) => GreaterThanOrEqual(l, r)
-        // not(l <= r)  =>  l > r
-        case LessThanOrEqual(l, r) => GreaterThan(l, r)
-        // not(l || r) => not(l) && not(r)
-        case Or(l, r) => And(Not(l), Not(r))
-        // not(l && r) => not(l) or not(r)
-        case And(l, r) => Or(Not(l), Not(r))
-        // not(not(e))  =>  e
-        case Not(e) => e
-        case _ => not
-      }  // end of Not(exp)
-
-      // if (true) a else b  =>  a
-      // if (false) a else b  =>  b
-      case e @ If(Literal(v, _), trueValue, falseValue) => if (v == true) trueValue else falseValue
+        }
+
+      case Not(TrueLiteral) => FalseLiteral
+      case Not(FalseLiteral) => TrueLiteral
+
+      case Not(a GreaterThan b) => LessThanOrEqual(a, b)
+      case Not(a GreaterThanOrEqual b) => LessThan(a, b)
+
+      case Not(a LessThan b) => GreaterThanOrEqual(a, b)
+      case Not(a LessThanOrEqual b) => GreaterThan(a, b)
+
+      case Not(a Or b) => And(Not(a), Not(b))
+      case Not(a And b) => Or(Not(a), Not(b))
+
+      case Not(Not(e)) => e
+
+      case If(TrueLiteral, trueValue, _) => trueValue
+      case If(FalseLiteral, _, falseValue) => falseValue
     }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org