You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2016/04/03 02:48:56 UTC

spark git commit: [SPARK-14338][SQL] Improve `SimplifyConditionals` rule to handle `null` in IF/CASEWHEN

Repository: spark
Updated Branches:
  refs/heads/master a3e293542 -> f70503761


[SPARK-14338][SQL] Improve `SimplifyConditionals` rule to handle `null` in IF/CASEWHEN

## What changes were proposed in this pull request?

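Currently, `SimplifyConditionals` only simplifies `If` and `CaseWhen` expressions whose conditions are the literals `true` or `false`. This PR extends the rule to `null` literal conditions as well: a `null` condition never selects its branch, so it is treated the same way as `false` (the `If` collapses to its else value, and the `CaseWhen` branch is removed).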

**Before**
```
scala> sql("SELECT IF(null, 1, 0)").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [if (null) 1 else 0 AS (IF(CAST(NULL AS BOOLEAN), 1, 0))#4]
:     +- INPUT
+- Scan OneRowRelation[]
scala> sql("select case when cast(null as boolean) then 1 else 2 end").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [CASE WHEN null THEN 1 ELSE 2 END AS CASE WHEN CAST(NULL AS BOOLEAN) THEN 1 ELSE 2 END#14]
:     +- INPUT
+- Scan OneRowRelation[]
```

**After**
```
scala> sql("SELECT IF(null, 1, 0)").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [0 AS (IF(CAST(NULL AS BOOLEAN), 1, 0))#4]
:     +- INPUT
+- Scan OneRowRelation[]
scala> sql("select case when cast(null as boolean) then 1 else 2 end").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [2 AS CASE WHEN CAST(NULL AS BOOLEAN) THEN 1 ELSE 2 END#4]
:     +- INPUT
+- Scan OneRowRelation[]
```

**Hive**
```
hive> select if(null,1,2);
OK
2
hive> select case when cast(null as boolean) then 1 else 2 end;
OK
2
```

## How was this patch tested?

Passes the Jenkins tests, including the newly extended test cases in `SimplifyConditionalSuite`.

Author: Dongjoon Hyun <do...@apache.org>

Closes #12122 from dongjoon-hyun/SPARK-14338.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f7050376
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f7050376
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f7050376

Branch: refs/heads/master
Commit: f705037617d55bb479ec60bcb1e55c736224be94
Parents: a3e2935
Author: Dongjoon Hyun <do...@apache.org>
Authored: Sat Apr 2 17:48:53 2016 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Sat Apr 2 17:48:53 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/optimizer/Optimizer.scala    | 13 ++++++++++---
 .../optimizer/SimplifyConditionalSuite.scala        | 16 +++++++++++-----
 2 files changed, 21 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f7050376/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 326933e..a5ab390 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -527,7 +527,7 @@ object LikeSimplification extends Rule[LogicalPlan] {
  * Null value propagation from bottom to top of the expression tree.
  */
 object NullPropagation extends Rule[LogicalPlan] {
-  def nonNullLiteral(e: Expression): Boolean = e match {
+  private def nonNullLiteral(e: Expression): Boolean = e match {
     case Literal(null, _) => false
     case _ => true
   }
@@ -773,17 +773,24 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
  * Simplifies conditional expressions (if / case).
  */
 object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
+  private def falseOrNullLiteral(e: Expression): Boolean = e match {
+    case FalseLiteral => true
+    case Literal(null, _) => true
+    case _ => false
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsUp {
       case If(TrueLiteral, trueValue, _) => trueValue
       case If(FalseLiteral, _, falseValue) => falseValue
+      case If(Literal(null, _), _, falseValue) => falseValue
 
-      case e @ CaseWhen(branches, elseValue) if branches.exists(_._1 == FalseLiteral) =>
+      case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
         // If there are branches that are always false, remove them.
         // If there are no more branches left, just use the else value.
         // Note that these two are handled together here in a single case statement because
         // otherwise we cannot determine the data type for the elseValue if it is None (i.e. null).
-        val newBranches = branches.filter(_._1 != FalseLiteral)
+        val newBranches = branches.filter(x => !falseOrNullLiteral(x._1))
         if (newBranches.isEmpty) {
           elseValue.getOrElse(Literal.create(null, e.dataType))
         } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/f7050376/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
index d436b62..33239c0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{IntegerType, NullType}
 
 
 class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
@@ -41,6 +41,7 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
   private val trueBranch = (TrueLiteral, Literal(5))
   private val normalBranch = (NonFoldableLiteral(true), Literal(10))
   private val unreachableBranch = (FalseLiteral, Literal(20))
+  private val nullBranch = (Literal(null, NullType), Literal(30))
 
   test("simplify if") {
     assertEquivalent(
@@ -50,18 +51,22 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
     assertEquivalent(
       If(FalseLiteral, Literal(10), Literal(20)),
       Literal(20))
+
+    assertEquivalent(
+      If(Literal(null, NullType), Literal(10), Literal(20)),
+      Literal(20))
   }
 
   test("remove unreachable branches") {
     // i.e. removing branches whose conditions are always false
     assertEquivalent(
-      CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: Nil, None),
+      CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
       CaseWhen(normalBranch :: Nil, None))
   }
 
   test("remove entire CaseWhen if only the else branch is reachable") {
     assertEquivalent(
-      CaseWhen(unreachableBranch :: unreachableBranch :: Nil, Some(Literal(30))),
+      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: Nil, Some(Literal(30))),
       Literal(30))
 
     assertEquivalent(
@@ -71,12 +76,13 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
 
   test("remove entire CaseWhen if the first branch is always true") {
     assertEquivalent(
-      CaseWhen(trueBranch :: normalBranch :: Nil, None),
+      CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None),
       Literal(5))
 
     // Test branch elimination and simplification in combination
     assertEquivalent(
-      CaseWhen(unreachableBranch :: unreachableBranch:: trueBranch :: normalBranch :: Nil, None),
+      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch
+        :: Nil, None),
       Literal(5))
 
     // Make sure this doesn't trigger if there is a non-foldable branch before the true branch


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org