Posted to commits@spark.apache.org by rx...@apache.org on 2016/01/13 03:45:57 UTC

spark git commit: [SPARK-12788][SQL] Simplify BooleanEquality by using casts.

Repository: spark
Updated Branches:
  refs/heads/master 924708496 -> b3b9ad23c


[SPARK-12788][SQL] Simplify BooleanEquality by using casts.

Author: Reynold Xin <rx...@databricks.com>

Closes #10730 from rxin/SPARK-12788.
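
For context, a minimal sketch of the rewrite this patch introduces (illustrative only; it assumes a Spark build on the classpath and uses the Catalyst expression classes that appear in the diff below):

    import org.apache.spark.sql.catalyst.expressions.{Add, Cast, EqualTo, Literal}
    import org.apache.spark.sql.types.Decimal

    // A non-literal numeric expression, so the literal-folding cases of the rule do not fire.
    val one  = Add(Literal(Decimal(1)), Literal(Decimal(0)))
    val bool = Literal(true)

    // Before this patch the rule expanded the comparison into an If/CaseKeyWhen tree;
    // after it, the boolean side is simply cast to the numeric side's type.
    val coerced = EqualTo(Cast(bool, one.dataType), one)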


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b3b9ad23
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b3b9ad23
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b3b9ad23

Branch: refs/heads/master
Commit: b3b9ad23cffc1c6d83168487093e4c03d49e1c2c
Parents: 9247084
Author: Reynold Xin <rx...@databricks.com>
Authored: Tue Jan 12 18:45:55 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Tue Jan 12 18:45:55 2016 -0800

----------------------------------------------------------------------
 .../catalyst/analysis/HiveTypeCoercion.scala    | 30 ++++----------------
 .../analysis/HiveTypeCoercionSuite.scala        | 28 +++++++++++++++++-
 2 files changed, 32 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b3b9ad23/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index e9e2067..980b5d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -482,27 +482,6 @@ object HiveTypeCoercion {
     private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE)
     private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO)
 
-    private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = {
-      CaseKeyWhen(numericExpr, Seq(
-        Literal(trueValues.head), booleanExpr,
-        Literal(falseValues.head), Not(booleanExpr),
-        Literal(false)))
-    }
-
-    private def transform(booleanExpr: Expression, numericExpr: Expression) = {
-      If(Or(IsNull(booleanExpr), IsNull(numericExpr)),
-        Literal.create(null, BooleanType),
-        buildCaseKeyWhen(booleanExpr, numericExpr))
-    }
-
-    private def transformNullSafe(booleanExpr: Expression, numericExpr: Expression) = {
-      CaseWhen(Seq(
-        And(IsNull(booleanExpr), IsNull(numericExpr)), Literal(true),
-        Or(IsNull(booleanExpr), IsNull(numericExpr)), Literal(false),
-        buildCaseKeyWhen(booleanExpr, numericExpr)
-      ))
-    }
-
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes whose children have not been resolved yet.
       case e if !e.childrenResolved => e
@@ -511,6 +490,7 @@ object HiveTypeCoercion {
       // all other cases are considered as false.
 
       // We may simplify the expression if one side is literal numeric values
+      // TODO: Maybe these rules should go into the optimizer.
       case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType))
         if trueValues.contains(value) => bool
       case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType))
@@ -529,13 +509,13 @@ object HiveTypeCoercion {
         if falseValues.contains(value) => And(IsNotNull(bool), Not(bool))
 
       case EqualTo(left @ BooleanType(), right @ NumericType()) =>
-        transform(left , right)
+        EqualTo(Cast(left, right.dataType), right)
       case EqualTo(left @ NumericType(), right @ BooleanType()) =>
-        transform(right, left)
+        EqualTo(left, Cast(right, left.dataType))
       case EqualNullSafe(left @ BooleanType(), right @ NumericType()) =>
-        transformNullSafe(left, right)
+        EqualNullSafe(Cast(left, right.dataType), right)
       case EqualNullSafe(left @ NumericType(), right @ BooleanType()) =>
-        transformNullSafe(right, left)
+        EqualNullSafe(left, Cast(right, left.dataType))
     }
   }
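
The removed transform/transformNullSafe helpers existed mainly to get null handling right. With the cast-based rewrite those semantics fall out of Cast and EqualNullSafe themselves; a minimal sketch of why (illustrative only, assuming a Spark build on the classpath):

    import org.apache.spark.sql.catalyst.expressions.{Cast, EqualNullSafe, Literal}
    import org.apache.spark.sql.types.{BooleanType, DecimalType}

    // Casting a null boolean yields a null numeric value, so a null-safe
    // comparison of two nulls still evaluates to true, matching the old
    // transformNullSafe behaviour.
    val nullBool = Literal.create(null, BooleanType)
    val nullNum  = Literal.create(null, DecimalType(10, 0))
    val coerced  = EqualNullSafe(Cast(nullBool, nullNum.dataType), nullNum)
    // coerced.eval() returns true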
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b3b9ad23/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 23b11af..40378c6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -320,7 +320,33 @@ class HiveTypeCoercionSuite extends PlanTest {
     )
   }
 
-  test("type coercion simplification for equal to") {
+  test("BooleanEquality type cast") {
+    val be = HiveTypeCoercion.BooleanEquality
+    // Use something more than a literal to avoid triggering the simplification rules.
+    val one = Add(Literal(Decimal(1)), Literal(Decimal(0)))
+
+    ruleTest(be,
+      EqualTo(Literal(true), one),
+      EqualTo(Cast(Literal(true), one.dataType), one)
+    )
+
+    ruleTest(be,
+      EqualTo(one, Literal(true)),
+      EqualTo(one, Cast(Literal(true), one.dataType))
+    )
+
+    ruleTest(be,
+      EqualNullSafe(Literal(true), one),
+      EqualNullSafe(Cast(Literal(true), one.dataType), one)
+    )
+
+    ruleTest(be,
+      EqualNullSafe(one, Literal(true)),
+      EqualNullSafe(one, Cast(Literal(true), one.dataType))
+    )
+  }
+
+  test("BooleanEquality simplification") {
     val be = HiveTypeCoercion.BooleanEquality
 
     ruleTest(be,

