You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2018/12/21 22:52:46 UTC

[spark] branch branch-2.3 updated: [SPARK-26366][SQL][BACKPORT-2.3] ReplaceExceptWithFilter should consider NULL as False

This is an automated email from the ASF dual-hosted git repository.

lixiao pushed a commit to branch branch-2.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.3 by this push:
     new a7d50ae  [SPARK-26366][SQL][BACKPORT-2.3] ReplaceExceptWithFilter should consider NULL as False
a7d50ae is described below

commit a7d50ae24a5f92e8d9b6622436f0bb4c2e06cbe1
Author: Marco Gaido <ma...@gmail.com>
AuthorDate: Fri Dec 21 14:52:29 2018 -0800

    [SPARK-26366][SQL][BACKPORT-2.3] ReplaceExceptWithFilter should consider NULL as False
    
    ## What changes were proposed in this pull request?
    
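    `ReplaceExceptWithFilter` does not properly handle the case in which the condition evaluates to NULL. Since negating NULL still yields NULL, the assumption that negating the condition returns all the rows which did not satisfy it does not hold: rows for which the condition evaluates to NULL are filtered out by the negated condition too. For example, for a row with `a = '1'` and `b = NULL`, the condition `a IN ('0') OR b > 'c'` evaluates to NULL, so the row is excluded both by the right-hand plan and by the rewritten `NOT(condition)` filter, and is therefore wrongly missing from the `EXCEPT` result. This happens when the constraints inferred by `InferFiltersFromConstraints` are not sufficient, as is the case with `OR` conditions.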
    
    The rule also had problems with non-deterministic conditions: in that scenario, the rewrite would change the probability distribution of the output.
    
    This PR fixes these problems by:
     - treating the condition as False when it evaluates to NULL (so that all the rows which did not satisfy it are returned);
     - skipping the transformation entirely when the condition is non-deterministic.
    
    ## How was this patch tested?
    
    Added unit tests.
    
    Closes #23350 from mgaido91/SPARK-26366_2.3.
    
    Authored-by: Marco Gaido <ma...@gmail.com>
    Signed-off-by: gatorsmile <ga...@gmail.com>
---
 .../optimizer/ReplaceExceptWithFilter.scala        | 32 +++++++++-------
 .../catalyst/optimizer/ReplaceOperatorSuite.scala  | 44 ++++++++++++++++------
 .../scala/org/apache/spark/sql/DatasetSuite.scala  | 11 ++++++
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 38 +++++++++++++++++++
 4 files changed, 101 insertions(+), 24 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala
index 45edf26..08cf160 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala
@@ -36,7 +36,8 @@ import org.apache.spark.sql.catalyst.rules.Rule
  * Note:
  * Before flipping the filter condition of the right node, we should:
  * 1. Combine all it's [[Filter]].
- * 2. Apply InferFiltersFromConstraints rule (to take into account of NULL values in the condition).
+ * 2. Update the attribute references to the left node;
+ * 3. Add a Coalesce(condition, False) (to take into account of NULL values in the condition).
  */
 object ReplaceExceptWithFilter extends Rule[LogicalPlan] {
 
@@ -47,23 +48,28 @@ object ReplaceExceptWithFilter extends Rule[LogicalPlan] {
 
     plan.transform {
       case e @ Except(left, right) if isEligible(left, right) =>
-        val newCondition = transformCondition(left, skipProject(right))
-        newCondition.map { c =>
-          Distinct(Filter(Not(c), left))
-        }.getOrElse {
+        val filterCondition = combineFilters(skipProject(right)).asInstanceOf[Filter].condition
+        if (filterCondition.deterministic) {
+          transformCondition(left, filterCondition).map { c =>
+            Distinct(Filter(Not(c), left))
+          }.getOrElse {
+            e
+          }
+        } else {
           e
         }
     }
   }
 
-  private def transformCondition(left: LogicalPlan, right: LogicalPlan): Option[Expression] = {
-    val filterCondition =
-      InferFiltersFromConstraints(combineFilters(right)).asInstanceOf[Filter].condition
-
-    val attributeNameMap: Map[String, Attribute] = left.output.map(x => (x.name, x)).toMap
-
-    if (filterCondition.references.forall(r => attributeNameMap.contains(r.name))) {
-      Some(filterCondition.transform { case a: AttributeReference => attributeNameMap(a.name) })
+  private def transformCondition(plan: LogicalPlan, condition: Expression): Option[Expression] = {
+    val attributeNameMap: Map[String, Attribute] = plan.output.map(x => (x.name, x)).toMap
+    if (condition.references.forall(r => attributeNameMap.contains(r.name))) {
+      val rewrittenCondition = condition.transform {
+        case a: AttributeReference => attributeNameMap(a.name)
+      }
+      // We need to consider as False when the condition is NULL, otherwise we do not return those
+      // rows containing NULL which are instead filtered in the Except right plan
+      Some(Coalesce(Seq(rewrittenCondition, Literal.FalseLiteral)))
     } else {
       None
     }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
index 52dc2e9..78d3969 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
@@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, Not}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Coalesce, If, Literal, Not}
 import org.apache.spark.sql.catalyst.expressions.aggregate.First
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.BooleanType
 
 class ReplaceOperatorSuite extends PlanTest {
 
@@ -65,8 +66,7 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA >= 2 && attributeB < 1)),
+        Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
           Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze
 
     comparePlans(optimized, correctAnswer)
@@ -84,8 +84,8 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA >= 2 && attributeB < 1)), table1)).analyze
+        Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
+          table1)).analyze
 
     comparePlans(optimized, correctAnswer)
   }
@@ -104,8 +104,7 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA >= 2 && attributeB < 1)),
+        Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
           Project(Seq(attributeA, attributeB), table1))).analyze
 
     comparePlans(optimized, correctAnswer)
@@ -125,8 +124,7 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-          Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-            (attributeA >= 2 && attributeB < 1)),
+          Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
             Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze
 
     comparePlans(optimized, correctAnswer)
@@ -146,8 +144,7 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA === 1 && attributeB === 2)),
+        Filter(Not(Coalesce(Seq(attributeA === 1 && attributeB === 2, Literal.FalseLiteral))),
           Project(Seq(attributeA, attributeB),
             Filter(attributeB < 1, Filter(attributeA >= 2, table1))))).analyze
 
@@ -229,4 +226,29 @@ class ReplaceOperatorSuite extends PlanTest {
 
     comparePlans(optimized, query)
   }
+
+  test("SPARK-26366: ReplaceExceptWithFilter should handle properly NULL") {
+    val basePlan = LocalRelation(Seq('a.int, 'b.int))
+    val otherPlan = basePlan.where('a.in(1, 2) || 'b.in())
+    val except = Except(basePlan, otherPlan)
+    val result = OptimizeIn(Optimize.execute(except.analyze))
+    val correctAnswer = Aggregate(basePlan.output, basePlan.output,
+      Filter(!Coalesce(Seq(
+        'a.in(1, 2) || If('b.isNotNull, Literal.FalseLiteral, Literal(null, BooleanType)),
+        Literal.FalseLiteral)),
+        basePlan)).analyze
+    comparePlans(result, correctAnswer)
+  }
+
+  test("SPARK-26366: ReplaceExceptWithFilter should not transform non-detrministic") {
+    val basePlan = LocalRelation(Seq('a.int, 'b.int))
+    val otherPlan = basePlan.where('a > rand(1L))
+    val except = Except(basePlan, otherPlan)
+    val result = Optimize.execute(except.analyze)
+    val condition = basePlan.output.zip(otherPlan.output).map { case (a1, a2) =>
+      a1 <=> a2 }.reduce( _ && _)
+    val correctAnswer = Aggregate(basePlan.output, otherPlan.output,
+      Join(basePlan, otherPlan, LeftAnti, Option(condition))).analyze
+    comparePlans(result, correctAnswer)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 3b7bd84..522ed8d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1467,6 +1467,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkAnswer(df.groupBy(col("a")).agg(first(col("b"))),
       Seq(Row("0", BigDecimal.valueOf(0.1111)), Row("1", BigDecimal.valueOf(1.1111))))
   }
+
+  test("SPARK-26366: return nulls which are not filtered in except") {
+    val inputDF = sqlContext.createDataFrame(
+      sparkContext.parallelize(Seq(Row("0", "a"), Row("1", null))),
+      StructType(Seq(
+        StructField("a", StringType, nullable = true),
+        StructField("b", StringType, nullable = true))))
+
+    val exceptDF = inputDF.filter(col("a").isin("0") or col("b") > "c")
+    checkAnswer(inputDF.except(exceptDF), Seq(Row("1", null)))
+  }
 }
 
 case class TestDataUnion(x: Int, y: Int, z: Int)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 0af6d87..6848b66 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2831,6 +2831,44 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       checkAnswer(sql("select 26393499451 / (1e6 * 1000)"), Row(BigDecimal("26.3934994510000")))
     }
   }
+
+  test("SPARK-26366: verify ReplaceExceptWithFilter") {
+    Seq(true, false).foreach { enabled =>
+      withSQLConf(SQLConf.REPLACE_EXCEPT_WITH_FILTER.key -> enabled.toString) {
+        val df = spark.createDataFrame(
+          sparkContext.parallelize(Seq(Row(0, 3, 5),
+            Row(0, 3, null),
+            Row(null, 3, 5),
+            Row(0, null, 5),
+            Row(0, null, null),
+            Row(null, null, 5),
+            Row(null, 3, null),
+            Row(null, null, null))),
+          StructType(Seq(StructField("c1", IntegerType),
+            StructField("c2", IntegerType),
+            StructField("c3", IntegerType))))
+        val where = "c2 >= 3 OR c1 >= 0"
+        val whereNullSafe =
+          """
+            |(c2 IS NOT NULL AND c2 >= 3)
+            |OR (c1 IS NOT NULL AND c1 >= 0)
+          """.stripMargin
+
+        val df_a = df.filter(where)
+        val df_b = df.filter(whereNullSafe)
+        checkAnswer(df.except(df_a), df.except(df_b))
+
+        val whereWithIn = "c2 >= 3 OR c1 in (2)"
+        val whereWithInNullSafe =
+          """
+            |(c2 IS NOT NULL AND c2 >= 3)
+          """.stripMargin
+        val dfIn_a = df.filter(whereWithIn)
+        val dfIn_b = df.filter(whereWithInNullSafe)
+        checkAnswer(df.except(dfIn_a), df.except(dfIn_b))
+      }
+    }
+  }
 }
 
 case class Foo(bar: Option[String])


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org