Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2018/12/21 22:55:10 UTC

[GitHub] gatorsmile closed pull request #23350: [SPARK-26366][SQL][BACKPORT-2.3] ReplaceExceptWithFilter should consider NULL as False

URL: https://github.com/apache/spark/pull/23350

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala
index 45edf266bbce4..08cf16038a654 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala
@@ -36,7 +36,8 @@ import org.apache.spark.sql.catalyst.rules.Rule
  * Note:
  * Before flipping the filter condition of the right node, we should:
  * 1. Combine all its [[Filter]].
- * 2. Apply InferFiltersFromConstraints rule (to take into account of NULL values in the condition).
+ * 2. Update the attribute references to the left node;
+ * 3. Add a Coalesce(condition, False) (to take NULL values in the condition into account).
  */
 object ReplaceExceptWithFilter extends Rule[LogicalPlan] {
 
@@ -47,23 +48,28 @@ object ReplaceExceptWithFilter extends Rule[LogicalPlan] {
 
     plan.transform {
       case e @ Except(left, right) if isEligible(left, right) =>
-        val newCondition = transformCondition(left, skipProject(right))
-        newCondition.map { c =>
-          Distinct(Filter(Not(c), left))
-        }.getOrElse {
+        val filterCondition = combineFilters(skipProject(right)).asInstanceOf[Filter].condition
+        if (filterCondition.deterministic) {
+          transformCondition(left, filterCondition).map { c =>
+            Distinct(Filter(Not(c), left))
+          }.getOrElse {
+            e
+          }
+        } else {
           e
         }
     }
   }
 
-  private def transformCondition(left: LogicalPlan, right: LogicalPlan): Option[Expression] = {
-    val filterCondition =
-      InferFiltersFromConstraints(combineFilters(right)).asInstanceOf[Filter].condition
-
-    val attributeNameMap: Map[String, Attribute] = left.output.map(x => (x.name, x)).toMap
-
-    if (filterCondition.references.forall(r => attributeNameMap.contains(r.name))) {
-      Some(filterCondition.transform { case a: AttributeReference => attributeNameMap(a.name) })
+  private def transformCondition(plan: LogicalPlan, condition: Expression): Option[Expression] = {
+    val attributeNameMap: Map[String, Attribute] = plan.output.map(x => (x.name, x)).toMap
+    if (condition.references.forall(r => attributeNameMap.contains(r.name))) {
+      val rewrittenCondition = condition.transform {
+        case a: AttributeReference => attributeNameMap(a.name)
+      }
+      // We need to treat the condition as False when it is NULL; otherwise we would not
+      // return the rows containing NULL, which are filtered out by the Except's right plan.
+      Some(Coalesce(Seq(rewrittenCondition, Literal.FalseLiteral)))
     } else {
       None
     }
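
The core of the change is visible above: rather than relying on the
InferFiltersFromConstraints rule, the rewritten condition is wrapped in
Coalesce(condition, FALSE), so a condition that evaluates to NULL is
treated as FALSE before negation. A minimal sketch in plain Scala (no
Spark required; Option[Boolean] models SQL's three-valued logic, with
None standing for NULL, and the object name is made up for illustration)
shows why this matters:

    object NullSemanticsSketch {
      // SQL NOT: NOT NULL is NULL
      def sqlNot(v: Option[Boolean]): Option[Boolean] = v.map(!_)

      // COALESCE(v, FALSE): NULL becomes FALSE
      def coalesceFalse(v: Option[Boolean]): Option[Boolean] = v.orElse(Some(false))

      def main(args: Array[String]): Unit = {
        // Suppose the right side's filter condition evaluated to NULL for a row.
        val condition: Option[Boolean] = None

        // Old rewrite: Filter(Not(condition)). NULL is not TRUE, so the row is
        // dropped from the EXCEPT result even though the right plan dropped it too.
        println(sqlNot(condition))                // None: the row is lost

        // New rewrite: Filter(Not(Coalesce(condition, FALSE))) keeps the row.
        println(sqlNot(coalesceFalse(condition))) // Some(true): the row is kept
      }
    }
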
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
index 52dc2e9fb076c..78d3969906e99 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
@@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, Not}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Coalesce, If, Literal, Not}
 import org.apache.spark.sql.catalyst.expressions.aggregate.First
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.BooleanType
 
 class ReplaceOperatorSuite extends PlanTest {
 
@@ -65,8 +66,7 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA >= 2 && attributeB < 1)),
+        Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
           Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze
 
     comparePlans(optimized, correctAnswer)
@@ -84,8 +84,8 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA >= 2 && attributeB < 1)), table1)).analyze
+        Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
+          table1)).analyze
 
     comparePlans(optimized, correctAnswer)
   }
@@ -104,8 +104,7 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA >= 2 && attributeB < 1)),
+        Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
           Project(Seq(attributeA, attributeB), table1))).analyze
 
     comparePlans(optimized, correctAnswer)
@@ -125,8 +124,7 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-          Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-            (attributeA >= 2 && attributeB < 1)),
+          Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
             Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze
 
     comparePlans(optimized, correctAnswer)
@@ -146,8 +144,7 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA === 1 && attributeB === 2)),
+        Filter(Not(Coalesce(Seq(attributeA === 1 && attributeB === 2, Literal.FalseLiteral))),
           Project(Seq(attributeA, attributeB),
             Filter(attributeB < 1, Filter(attributeA >= 2, table1))))).analyze
 
@@ -229,4 +226,29 @@ class ReplaceOperatorSuite extends PlanTest {
 
     comparePlans(optimized, query)
   }
+
+  test("SPARK-26366: ReplaceExceptWithFilter should handle properly NULL") {
+    val basePlan = LocalRelation(Seq('a.int, 'b.int))
+    val otherPlan = basePlan.where('a.in(1, 2) || 'b.in())
+    val except = Except(basePlan, otherPlan)
+    val result = OptimizeIn(Optimize.execute(except.analyze))
+    val correctAnswer = Aggregate(basePlan.output, basePlan.output,
+      Filter(!Coalesce(Seq(
+        'a.in(1, 2) || If('b.isNotNull, Literal.FalseLiteral, Literal(null, BooleanType)),
+        Literal.FalseLiteral)),
+        basePlan)).analyze
+    comparePlans(result, correctAnswer)
+  }
+
+  test("SPARK-26366: ReplaceExceptWithFilter should not transform non-detrministic") {
+    val basePlan = LocalRelation(Seq('a.int, 'b.int))
+    val otherPlan = basePlan.where('a > rand(1L))
+    val except = Except(basePlan, otherPlan)
+    val result = Optimize.execute(except.analyze)
+    val condition = basePlan.output.zip(otherPlan.output).map { case (a1, a2) =>
+      a1 <=> a2 }.reduce( _ && _)
+    val correctAnswer = Aggregate(basePlan.output, otherPlan.output,
+      Join(basePlan, otherPlan, LeftAnti, Option(condition))).analyze
+    comparePlans(result, correctAnswer)
+  }
 }
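
The optimizer tests above can be mirrored end to end from a shell. A
hypothetical spark-shell reproduction of the bug (assuming only the
SparkSession named spark that the shell provides; the column names and
data below are made up) might look like:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.{StringType, StructField, StructType}

    val schema = StructType(Seq(
      StructField("a", StringType, nullable = true),
      StructField("b", StringType, nullable = true)))
    val input = spark.createDataFrame(
      spark.sparkContext.parallelize(Seq(Row("0", "a"), Row("1", null))), schema)

    // ("1", null) does not satisfy the condition: it evaluates to NULL.
    val filtered = input.filter(col("a").isin("0") or col("b") > "c")

    // Expected output: ("1", null). Before this patch the EXCEPT was rewritten
    // into Filter(Not(condition)), which silently dropped the NULL row as well.
    input.except(filtered).show()
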
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 3b7bd84be047f..522ed8d4a42d1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1467,6 +1467,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkAnswer(df.groupBy(col("a")).agg(first(col("b"))),
       Seq(Row("0", BigDecimal.valueOf(0.1111)), Row("1", BigDecimal.valueOf(1.1111))))
   }
+
+  test("SPARK-26366: return nulls which are not filtered in except") {
+    val inputDF = sqlContext.createDataFrame(
+      sparkContext.parallelize(Seq(Row("0", "a"), Row("1", null))),
+      StructType(Seq(
+        StructField("a", StringType, nullable = true),
+        StructField("b", StringType, nullable = true))))
+
+    val exceptDF = inputDF.filter(col("a").isin("0") or col("b") > "c")
+    checkAnswer(inputDF.except(exceptDF), Seq(Row("1", null)))
+  }
 }
 
 case class TestDataUnion(x: Int, y: Int, z: Int)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 0af6d87150324..6848b66514ef8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2831,6 +2831,44 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       checkAnswer(sql("select 26393499451 / (1e6 * 1000)"), Row(BigDecimal("26.3934994510000")))
     }
   }
+
+  test("SPARK-26366: verify ReplaceExceptWithFilter") {
+    Seq(true, false).foreach { enabled =>
+      withSQLConf(SQLConf.REPLACE_EXCEPT_WITH_FILTER.key -> enabled.toString) {
+        val df = spark.createDataFrame(
+          sparkContext.parallelize(Seq(Row(0, 3, 5),
+            Row(0, 3, null),
+            Row(null, 3, 5),
+            Row(0, null, 5),
+            Row(0, null, null),
+            Row(null, null, 5),
+            Row(null, 3, null),
+            Row(null, null, null))),
+          StructType(Seq(StructField("c1", IntegerType),
+            StructField("c2", IntegerType),
+            StructField("c3", IntegerType))))
+        val where = "c2 >= 3 OR c1 >= 0"
+        val whereNullSafe =
+          """
+            |(c2 IS NOT NULL AND c2 >= 3)
+            |OR (c1 IS NOT NULL AND c1 >= 0)
+          """.stripMargin
+
+        val df_a = df.filter(where)
+        val df_b = df.filter(whereNullSafe)
+        checkAnswer(df.except(df_a), df.except(df_b))
+
+        val whereWithIn = "c2 >= 3 OR c1 in (2)"
+        val whereWithInNullSafe =
+          """
+            |(c2 IS NOT NULL AND c2 >= 3)
+          """.stripMargin
+        val dfIn_a = df.filter(whereWithIn)
+        val dfIn_b = df.filter(whereWithInNullSafe)
+        checkAnswer(df.except(dfIn_a), df.except(dfIn_b))
+      }
+    }
+  }
 }
 
 case class Foo(bar: Option[String])
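
The SQLQuerySuite test above exercises both code paths by toggling
SQLConf.REPLACE_EXCEPT_WITH_FILTER. For experimenting outside the test
suite, here is a sketch of flipping the rule at runtime (assuming the
internal key spark.sql.optimizer.replaceExceptWithFilter, which is what
that SQLConf entry appears to map to on this branch, and a registered
table t1 that is made up for illustration):

    // Disable the Filter-based rewrite; EXCEPT falls back to a left-anti join.
    spark.conf.set("spark.sql.optimizer.replaceExceptWithFilter", "false")
    spark.sql("SELECT * FROM t1 EXCEPT SELECT * FROM t1 WHERE c2 >= 3").explain()

    // Restore the default, Coalesce-guarded Filter rewrite.
    spark.conf.set("spark.sql.optimizer.replaceExceptWithFilter", "true")
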


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org