You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by vi...@apache.org on 2021/11/07 08:19:20 UTC

[spark] branch master updated: [SPARK-36665][SQL] Add more Not operator simplifications

This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 977dd05  [SPARK-36665][SQL] Add more Not operator simplifications
977dd05 is described below

commit 977dd054ed0946b62e62d2d480dbf25598545a5e
Author: Kazuyuki Tanimura <kt...@apple.com>
AuthorDate: Sun Nov 7 01:17:58 2021 -0700

    [SPARK-36665][SQL] Add more Not operator simplifications
    
    ### What changes were proposed in this pull request?
    This PR proposes to add more Not operator simplifications through two new optimizer rules, `NotPropagation` and `NullDownPropagation`, which apply the following rewrites:
      - Not(null) == null
        - e.g. IsNull(Not(...)) can be IsNull(...)
      - (Not(a) = b) == (a = Not(b))
        - e.g. Not(...) = true can be (...) = false
      - (a != b) == (a = Not(b))
        - e.g. (...) != true can be (...) = false
    
    ### Why are the changes needed?
    This PR simplifies SQL statements that include Not operators.
    In addition, in the current implementation the following query does not push down its filter:
    ```
    SELECT * FROM t WHERE (not boolean_col) <=> null
    ```
    whereas the following equivalent query pushes down the filter as expected:
    ```
    SELECT * FROM t WHERE boolean_col <=> null
    ```
    That is because, in the current implementation, the first query produces `IsNull(Not(boolean_col))`, which could be simplified further to `IsNull(boolean_col)`.
    This PR helps optimize such cases.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added unit tests in `NotPropagationSuite` and `NullDownPropagationSuite`:
    ```
    build/sbt "testOnly *NotPropagationSuite *NullDownPropagationSuite"
    ```
    
    Closes #33930 from kazuyukitanimura/SPARK-36665.
    
    Authored-by: Kazuyuki Tanimura <kt...@apple.com>
    Signed-off-by: Liang-Chi Hsieh <vi...@gmail.com>
---
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |   2 +
 .../spark/sql/catalyst/optimizer/expressions.scala |  80 ++++++++++
 .../sql/catalyst/rules/RuleIdCollection.scala      |   2 +
 .../catalyst/optimizer/NotPropagationSuite.scala   | 176 +++++++++++++++++++++
 .../optimizer/NullDownPropagationSuite.scala       |  59 +++++++
 5 files changed, 319 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 298da4f..5386907 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -99,6 +99,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
         OptimizeRepartition,
         TransposeWindow,
         NullPropagation,
+        NullDownPropagation,
         ConstantPropagation,
         FoldablePropagation,
         OptimizeIn,
@@ -106,6 +107,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
         EliminateAggregateFilter,
         ReorderAssociativeOperator,
         LikeSimplification,
+        NotPropagation,
         BooleanSimplification,
         SimplifyConditionals,
         PushFoldableIntoBranches,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 0ec8bad..a32306f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -447,6 +447,53 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
 
 
 /**
+ * Moves or pushes a `Not` across `EqualTo`/`EqualNullSafe` when that lets the `Not` simplify.
+ */
+object NotPropagation extends Rule[LogicalPlan] {
+  // Returns true iff Not(x) can be rewritten into a simpler equivalent expression.
+  // E.g. if x == Not(y), canSimplifyNot(x) is true because Not(x) == Not(Not(y)) == y.
+  // For x = EqualTo(a, b), recursively check whether either child can absorb the Not.
+  // EqualNullSafe additionally requires both sides to be non-nullable because
+  // Not(EqualNullSafe(e, null)) is different from EqualNullSafe(e, Not(null)).
+  private def canSimplifyNot(x: Expression): Boolean = x match {
+    case Literal(_, BooleanType) | Literal(_, NullType) => true
+    case _: Not | _: IsNull | _: IsNotNull | _: And | _: Or => true
+    case _: GreaterThan | _: GreaterThanOrEqual | _: LessThan | _: LessThanOrEqual => true
+    case EqualTo(a, b) if canSimplifyNot(a) || canSimplifyNot(b) => true
+    case EqualNullSafe(a, b)
+      if !a.nullable && !b.nullable && (canSimplifyNot(a) || canSimplifyNot(b)) => true
+    case _ => false
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(NOT), ruleId) {
+    case q: LogicalPlan => q.transformExpressionsDownWithPruning(_.containsPattern(NOT), ruleId) {
+      // Move `Not` from one side of `EqualTo`/`EqualNullSafe` to the other side if it's beneficial.
+      // E.g. `EqualTo(Not(a), b)` with `b = Not(c)` becomes
+      // `EqualTo(a, Not(b))` => `EqualTo(a, Not(Not(c)))` => `EqualTo(a, c)`
+      // The `!canSimplifyNot(..) && canSimplifyNot(..)` guards make each move one-directional,
+      // so the mirrored cases below cannot keep undoing each other's rewrite.
+      case EqualTo(Not(a), b) if !canSimplifyNot(a) && canSimplifyNot(b) => EqualTo(a, Not(b))
+      case EqualTo(a, Not(b)) if canSimplifyNot(a) && !canSimplifyNot(b) => EqualTo(Not(a), b)
+      case EqualNullSafe(Not(a), b) if !canSimplifyNot(a) && canSimplifyNot(b) =>
+        EqualNullSafe(a, Not(b))
+      case EqualNullSafe(a, Not(b)) if canSimplifyNot(a) && !canSimplifyNot(b) =>
+        EqualNullSafe(Not(a), b)
+
+      // Push a top-level `Not` into one side of `EqualTo`/`EqualNullSafe` if it's beneficial.
+      // E.g. Not(EqualTo(x, false)) => EqualTo(x, Not(false)) => EqualTo(x, true)
+      case Not(EqualTo(a, b)) if canSimplifyNot(b) => EqualTo(a, Not(b))
+      case Not(EqualTo(a, b)) if canSimplifyNot(a) => EqualTo(Not(a), b)
+      case Not(EqualNullSafe(a, b)) if !a.nullable && !b.nullable && canSimplifyNot(b) =>
+        EqualNullSafe(a, Not(b))
+      case Not(EqualNullSafe(a, b)) if !a.nullable && !b.nullable && canSimplifyNot(a) =>
+        EqualNullSafe(Not(a), b)
+    }
+  }
+}
+
+
+/**
  * Simplifies binary comparisons with semantically-equal expressions:
  * 1) Replace '<=>' with 'true' literal.
  * 2) Replace '=', '<=', and '>=' with 'true' literal if both operands are non-nullable.
@@ -814,6 +861,39 @@ object NullPropagation extends Rule[LogicalPlan] {
 
 
 /**
+ * Unwraps the input of IsNull/IsNotNull when the input is a supported NullIntolerant expression.
+ * E.g. IsNull(Not(e)) == IsNull(e) and IsNull(a > b) == Or(IsNull(a), IsNull(b)).
+ */
+object NullDownPropagation extends Rule[LogicalPlan] {
+  // Allow-list of NullIntolerant expressions that also return non-null for all non-null inputs.
+  // Not all `NullIntolerant` can be propagated. E.g. `Cast` is `NullIntolerant`; however,
+  // cast('Infinity' as integer) is null. Hence, `Cast` is not a supported `NullIntolerant`.
+  // `ExtractValue` is also not supported. E.g. the planner may resolve column `a` to `a#123`,
+  // then IsNull(a#123) cannot be optimized.
+  // Applying to `EqualTo` is too disruptive for [SPARK-32290] optimization, not supported for now.
+  // If e has multiple children, the deterministic check is required because optimizing
+  // IsNull(a > b) to Or(IsNull(a), IsNull(b)), for example, may skip the evaluation of b.
+  private def supportedNullIntolerant(e: NullIntolerant): Boolean = (e match {
+    case _: Not => true
+    case _: GreaterThan | _: GreaterThanOrEqual | _: LessThan | _: LessThanOrEqual
+      if e.deterministic => true
+    case _ => false
+  }) && e.children.nonEmpty // reduceLeft in apply() needs at least one child
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(NULL_CHECK), ruleId) {
+    case q: LogicalPlan => q.transformExpressionsDownWithPruning(
+      _.containsPattern(NULL_CHECK), ruleId) {
+      case IsNull(e: NullIntolerant) if supportedNullIntolerant(e) =>
+        e.children.map(IsNull(_): Expression).reduceLeft(Or) // null iff any child is null
+      case IsNotNull(e: NullIntolerant) if supportedNullIntolerant(e) =>
+        e.children.map(IsNotNull(_): Expression).reduceLeft(And) // non-null iff no child is null
+    }
+  }
+}
+
+
+/**
  * Replace attributes with aliases of the original foldable expressions if possible.
  * Other optimizations will take advantage of the propagated foldable expressions. For example,
  * this rule can optimize
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index d207ebc..9792545 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -115,6 +115,8 @@ object RuleIdCollection {
       "org.apache.spark.sql.catalyst.optimizer.LikeSimplification" ::
       "org.apache.spark.sql.catalyst.optimizer.LimitPushDown" ::
       "org.apache.spark.sql.catalyst.optimizer.LimitPushDownThroughWindow" ::
+      "org.apache.spark.sql.catalyst.optimizer.NotPropagation" ::
+      "org.apache.spark.sql.catalyst.optimizer.NullDownPropagation" ::
       "org.apache.spark.sql.catalyst.optimizer.NullPropagation" ::
       "org.apache.spark.sql.catalyst.optimizer.ObjectSerializerPruning" ::
       "org.apache.spark.sql.catalyst.optimizer.OptimizeCsvJsonExprs" ::
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NotPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NotPropagationSuite.scala
new file mode 100644
index 0000000..d950609
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NotPropagationSuite.scala
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.types.BooleanType
+
+class NotPropagationSuite extends PlanTest with ExpressionEvalHelper {
+  // Runs NotPropagation in a FixedPoint(50) batch together with the null/boolean rules below.
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("AnalysisNodes", Once, EliminateSubqueryAliases) ::
+      Batch("Not Propagation", FixedPoint(50),
+        NullPropagation,
+        NullDownPropagation,
+        ConstantFolding,
+        SimplifyConditionals,
+        BooleanSimplification,
+        NotPropagation,
+        PruneFilters) :: Nil
+  }
+  // Columns: a, b, c (int), d (string), e, f, g, h (boolean).
+  val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string,
+    'e.boolean, 'f.boolean, 'g.boolean, 'h.boolean)
+  // NOTE(review): each Row holds only 4 values although testRelation declares 8 columns.
+  val testRelationWithData = LocalRelation.fromExternalRows(
+    testRelation.output, Seq(Row(1, 2, 3, "abc"))
+  )
+  // Optimizes a filter over testRelationWithData and compares it with a whole expected plan.
+  private def checkCondition(input: Expression, expected: LogicalPlan): Unit = {
+    val plan = testRelationWithData.where(input).analyze
+    val actual = Optimize.execute(plan)
+    comparePlans(actual, expected)
+  }
+  // Optimizes testRelation.where(input) and compares it with testRelation.where(expected).
+  private def checkCondition(input: Expression, expected: Expression): Unit = {
+    val plan = testRelation.where(input).analyze
+    val actual = Optimize.execute(plan)
+    val correctAnswer = testRelation.where(expected).analyze
+    comparePlans(actual, correctAnswer)
+  }
+
+  test("Using (Not(a) === b) == (a === Not(b)), (Not(a) <=> b) == (a <=> Not(b)) rules") {
+    checkCondition(Not('e) === Literal(true), 'e === Literal(false))
+    checkCondition(Not('e) === Literal(false), 'e === Literal(true))
+    checkCondition(Not('e) === Literal(null, BooleanType), testRelation)
+    checkCondition(Literal(true) === Not('e), Literal(false) === 'e)
+    checkCondition(Literal(false) === Not('e), Literal(true) === 'e)
+    checkCondition(Literal(null, BooleanType) === Not('e), testRelation)
+    checkCondition(Not('e) <=> Literal(true), 'e <=> Literal(false))
+    checkCondition(Not('e) <=> Literal(false), 'e <=> Literal(true))
+    checkCondition(Not('e) <=> Literal(null, BooleanType), IsNull('e))
+    checkCondition(Literal(true) <=> Not('e), Literal(false) <=> 'e)
+    checkCondition(Literal(false) <=> Not('e), Literal(true) <=> 'e)
+    checkCondition(Literal(null, BooleanType) <=> Not('e), IsNull('e))
+
+    checkCondition(Not('e) === Not('f), 'e === 'f)
+    checkCondition(Not('e) <=> Not('f), 'e <=> 'f)
+
+    checkCondition(IsNull('e) === Not('f), IsNotNull('e) === 'f)
+    checkCondition(Not('e) === IsNull('f), 'e === IsNotNull('f))
+    checkCondition(IsNull('e) <=> Not('f), IsNotNull('e) <=> 'f)
+    checkCondition(Not('e) <=> IsNull('f), 'e <=> IsNotNull('f))
+
+    checkCondition(IsNotNull('e) === Not('f), IsNull('e) === 'f)
+    checkCondition(Not('e) === IsNotNull('f), 'e === IsNull('f))
+    checkCondition(IsNotNull('e) <=> Not('f), IsNull('e) <=> 'f)
+    checkCondition(Not('e) <=> IsNotNull('f), 'e <=> IsNull('f))
+
+    checkCondition(Not('e) === Not(And('f, 'g)), 'e === And('f, 'g))
+    checkCondition(Not(And('e, 'f)) === Not('g), And('e, 'f) === 'g)
+    checkCondition(Not('e) <=> Not(And('f, 'g)), 'e <=> And('f, 'g))
+    checkCondition(Not(And('e, 'f)) <=> Not('g), And('e, 'f) <=> 'g)
+
+    checkCondition(Not('e) === Not(Or('f, 'g)), 'e === Or('f, 'g))
+    checkCondition(Not(Or('e, 'f)) === Not('g), Or('e, 'f) === 'g)
+    checkCondition(Not('e) <=> Not(Or('f, 'g)), 'e <=> Or('f, 'g))
+    checkCondition(Not(Or('e, 'f)) <=> Not('g), Or('e, 'f) <=> 'g)
+
+    checkCondition(('a > 'b) === Not('f), ('a <= 'b) === 'f)
+    checkCondition(Not('e) === ('a > 'b), 'e === ('a <= 'b))
+    checkCondition(('a > 'b) <=> Not('f), ('a <= 'b) <=> 'f)
+    checkCondition(Not('e) <=> ('a > 'b), 'e <=> ('a <= 'b))
+
+    checkCondition(('a >= 'b) === Not('f), ('a < 'b) === 'f)
+    checkCondition(Not('e) === ('a >= 'b), 'e === ('a < 'b))
+    checkCondition(('a >= 'b) <=> Not('f), ('a < 'b) <=> 'f)
+    checkCondition(Not('e) <=> ('a >= 'b), 'e <=> ('a < 'b))
+
+    checkCondition(('a < 'b) === Not('f), ('a >= 'b) === 'f)
+    checkCondition(Not('e) === ('a < 'b), 'e === ('a >= 'b))
+    checkCondition(('a < 'b) <=> Not('f), ('a >= 'b) <=> 'f)
+    checkCondition(Not('e) <=> ('a < 'b), 'e <=> ('a >= 'b))
+
+    checkCondition(('a <= 'b) === Not('f), ('a > 'b) === 'f)
+    checkCondition(Not('e) === ('a <= 'b), 'e === ('a > 'b))
+    checkCondition(('a <= 'b) <=> Not('f), ('a > 'b) <=> 'f)
+    checkCondition(Not('e) <=> ('a <= 'b), 'e <=> ('a > 'b))
+  }
+
+  test("Using (a =!= b) == (a === Not(b)), Not(a <=> b) == (a <=> Not(b)) rules") {
+    checkCondition('e =!= Literal(true), 'e === Literal(false))
+    checkCondition('e =!= Literal(false), 'e === Literal(true))
+    checkCondition('e =!= Literal(null, BooleanType), testRelation)
+    checkCondition(Literal(true) =!= 'e, Literal(false) === 'e)
+    checkCondition(Literal(false) =!= 'e, Literal(true) === 'e)
+    checkCondition(Literal(null, BooleanType) =!= 'e, testRelation)
+    checkCondition(Not(('a <=> 'b) <=> Literal(true)), ('a <=> 'b) <=> Literal(false))
+    checkCondition(Not(('a <=> 'b) <=> Literal(false)), ('a <=> 'b) <=> Literal(true))
+    checkCondition(Not(('a <=> 'b) <=> Literal(null, BooleanType)), testRelationWithData)
+    checkCondition(Not(Literal(true) <=> ('a <=> 'b)), Literal(false) <=> ('a <=> 'b))
+    checkCondition(Not(Literal(false) <=> ('a <=> 'b)), Literal(true) <=> ('a <=> 'b))
+    checkCondition(Not(Literal(null, BooleanType) <=> IsNull('e)), testRelationWithData)
+
+    checkCondition('e =!= Not('f), 'e === 'f)
+    checkCondition(Not('e) =!= 'f, 'e === 'f)
+    checkCondition(Not(('a <=> 'b) <=> Not(('b <=> 'c))), ('a <=> 'b) <=> ('b <=> 'c))
+    checkCondition(Not(Not(('a <=> 'b)) <=> ('b <=> 'c)), ('a <=> 'b) <=> ('b <=> 'c))
+
+    checkCondition('e =!= IsNull('f), 'e === IsNotNull('f))
+    checkCondition(IsNull('e) =!= 'f, IsNotNull('e) === 'f)
+    checkCondition(Not(('a <=> 'b) <=> IsNull('f)), ('a <=> 'b) <=> IsNotNull('f))
+    checkCondition(Not(IsNull('e) <=> ('b <=> 'c)), IsNotNull('e) <=> ('b <=> 'c))
+
+    checkCondition('e =!= IsNotNull('f), 'e === IsNull('f))
+    checkCondition(IsNotNull('e) =!= 'f, IsNull('e) === 'f)
+    checkCondition(Not(('a <=> 'b) <=> IsNotNull('f)), ('a <=> 'b) <=> IsNull('f))
+    checkCondition(Not(IsNotNull('e) <=> ('b <=> 'c)), IsNull('e) <=> ('b <=> 'c))
+
+    checkCondition('e =!= Not(And('f, 'g)), 'e === And('f, 'g))
+    checkCondition(Not(And('e, 'f)) =!= 'g, And('e, 'f) === 'g)
+    checkCondition('e =!= Not(Or('f, 'g)), 'e === Or('f, 'g))
+    checkCondition(Not(Or('e, 'f)) =!= 'g, Or('e, 'f) === 'g)
+
+    checkCondition(('a > 'b) =!= 'f, ('a <= 'b) === 'f)
+    checkCondition('e =!= ('a > 'b), 'e === ('a <= 'b))
+    checkCondition(('a >= 'b) =!= 'f, ('a < 'b) === 'f)
+    checkCondition('e =!= ('a >= 'b), 'e === ('a < 'b))
+    checkCondition(('a < 'b) =!= 'f, ('a >= 'b) === 'f)
+    checkCondition('e =!= ('a < 'b), 'e === ('a >= 'b))
+    checkCondition(('a <= 'b) =!= 'f, ('a > 'b) === 'f)
+    checkCondition('e =!= ('a <= 'b), 'e === ('a > 'b))
+
+    checkCondition('e =!= ('f === ('g === Not('h))), 'e === ('f === ('g === 'h)))
+
+  }
+
+  test("Properly avoid non optimize-able cases") {
+    checkCondition(Not(('a > 'b) <=> 'f), Not(('a > 'b) <=> 'f))
+    checkCondition(Not('e <=> ('a > 'b)), Not('e <=> ('a > 'b)))
+    checkCondition(('a === 'b) =!= ('a === 'c), ('a === 'b) =!= ('a === 'c))
+    checkCondition(('a === 'b) =!= ('c in(1, 2, 3)), ('a === 'b) =!= ('c in(1, 2, 3)))
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NullDownPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NullDownPropagationSuite.scala
new file mode 100644
index 0000000..c9d1f33
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NullDownPropagationSuite.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+
+class NullDownPropagationSuite extends PlanTest with ExpressionEvalHelper {
+  // Runs NullDownPropagation in a FixedPoint(50) batch with the null/boolean rules below.
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("AnalysisNodes", Once, EliminateSubqueryAliases) ::
+      Batch("Null Down Propagation", FixedPoint(50),
+        NullPropagation,
+        NullDownPropagation,
+        ConstantFolding,
+        SimplifyConditionals,
+        BooleanSimplification,
+        NotPropagation,
+        PruneFilters) :: Nil
+  }
+
+  val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string,
+    'e.boolean, 'f.boolean, 'g.boolean, 'h.boolean)
+  // Optimizes testRelation.where(input) and compares it with testRelation.where(expected).
+  private def checkCondition(input: Expression, expected: Expression): Unit = {
+    val plan = testRelation.where(input).analyze
+    val actual = Optimize.execute(plan)
+    val correctAnswer = testRelation.where(expected).analyze
+    comparePlans(actual, correctAnswer)
+  }
+
+  test("Using IsNull(e(inputs)) == IsNull(input1) or IsNull(input2) ... rules") {
+    checkCondition(IsNull(Not('e)), IsNull('e))
+    checkCondition(IsNotNull(Not('e)), IsNotNull('e))
+    checkCondition(IsNull('a > 'b), Or(IsNull('a), IsNull('b)))
+    checkCondition(IsNotNull('a > 'b), And(IsNotNull('a), IsNotNull('b)))
+  }
+}

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org