You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2021/04/20 03:37:47 UTC

[spark] branch branch-3.0 updated: [SPARK-35080][SQL] Only allow a subset of correlated equality predicates when a subquery is aggregated

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 5ac9605  [SPARK-35080][SQL] Only allow a subset of correlated equality predicates when a subquery is aggregated
5ac9605 is described below

commit 5ac9605be3f7906b67fd796abb9ddd8a57d8c8c0
Author: allisonwang-db <66...@users.noreply.github.com>
AuthorDate: Tue Apr 20 11:11:40 2021 +0800

    [SPARK-35080][SQL] Only allow a subset of correlated equality predicates when a subquery is aggregated
    
    This PR updated the `foundNonEqualCorrelatedPred` logic for correlated subqueries in `CheckAnalysis` to only allow correlated equality predicates that guarantee one-to-one mapping between inner and outer attributes, instead of all equality predicates.
    
    This change fixes correctness bugs. Before this fix, Spark could return wrong results for certain correlated subqueries that pass CheckAnalysis:
    Example 1:
    ```sql
    create or replace view t1(c) as values ('a'), ('b')
    create or replace view t2(c) as values ('ab'), ('abc'), ('bc')
    
    select c, (select count(*) from t2 where t1.c = substring(t2.c, 1, 1)) from t1
    ```
    Correct results: [(a, 2), (b, 1)]
    Spark results:
    ```
    +---+-----------------+
    |c  |scalarsubquery(c)|
    +---+-----------------+
    |a  |1                |
    |a  |1                |
    |b  |1                |
    +---+-----------------+
    ```
    Example 2:
    ```sql
    create or replace view t1(a, b) as values (0, 6), (1, 5), (2, 4), (3, 3);
    create or replace view t2(c) as values (6);
    
    select c, (select count(*) from t1 where a + b = c) from t2;
    ```
    Correct results: [(6, 4)]
    Spark results:
    ```
    +---+-----------------+
    |c  |scalarsubquery(c)|
    +---+-----------------+
    |6  |1                |
    |6  |1                |
    |6  |1                |
    |6  |1                |
    +---+-----------------+
    ```
    This introduces a user-facing change: users will no longer be able to run queries that contain unsupported correlated equality predicates; such queries now fail analysis.
    
    Added unit tests.
    
    Closes #32179 from allisonwang-db/spark-35080-subquery-bug.
    
    Lead-authored-by: allisonwang-db <66...@users.noreply.github.com>
    Co-authored-by: Wenchen Fan <cl...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit bad4b6f025de4946112a0897892a97d5ae6822cf)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../sql/catalyst/analysis/CheckAnalysis.scala      | 77 ++++++++++++++++++----
 .../sql/catalyst/analysis/AnalysisErrorSuite.scala | 24 +++++++
 .../sql-tests/results/udf/udf-except.sql.out       | 12 +++-
 .../scala/org/apache/spark/sql/SubquerySuite.scala | 11 +++-
 4 files changed, 109 insertions(+), 15 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index e5b8ae7..2887967 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -855,14 +855,72 @@ trait CheckAnalysis extends PredicateHelper {
     // +- SubqueryAlias t1, `t1`
     // +- Project [_1#73 AS c1#76, _2#74 AS c2#77]
     // +- LocalRelation [_1#73, _2#74]
-    def failOnNonEqualCorrelatedPredicate(found: Boolean, p: LogicalPlan): Unit = {
-      if (found) {
+    // SPARK-35080: The same issue can happen to correlated equality predicates when
+    // they do not guarantee one-to-one mapping between inner and outer attributes.
+    // For example:
+    // Table:
+    //   t1(a, b): [(0, 6), (1, 5), (2, 4)]
+    //   t2(c): [(6)]
+    //
+    // Query:
+    //   SELECT c, (SELECT COUNT(*) FROM t1 WHERE a + b = c) FROM t2
+    //
+    // Original subquery plan:
+    //   Aggregate [count(1)]
+    //   +- Filter ((a + b) = outer(c))
+    //      +- LocalRelation [a, b]
+    //
+    // Plan after pulling up correlated predicates:
+    //   Aggregate [a, b] [count(1), a, b]
+    //   +- LocalRelation [a, b]
+    //
+    // Plan after rewrite:
+    //   Project [c, count(1)]
+    //   +- Join LeftOuter ((a + b) = c)
+    //      :- LocalRelation [c]
+    //      +- Aggregate [a, b] [count(1), a, b]
+    //         +- LocalRelation [a, b]
+    //
+    // The right hand side of the join transformed from the subquery will output
+    //   count(1) | a | b
+    //      1     | 0 | 6
+    //      1     | 1 | 5
+    //      1     | 2 | 4
+    // and the plan after rewrite will give the original query incorrect results.
+    def failOnUnsupportedCorrelatedPredicate(predicates: Seq[Expression], p: LogicalPlan): Unit = {
+      if (predicates.nonEmpty) {
         // Report a non-supported case as an exception
-        failAnalysis(s"Correlated column is not allowed in a non-equality predicate:\n$p")
+        failAnalysis("Correlated column is not allowed in predicate " +
+          s"${predicates.map(_.sql).mkString}:\n$p")
       }
     }
 
-    var foundNonEqualCorrelatedPred: Boolean = false
+    def containsAttribute(e: Expression): Boolean = {
+      e.find(_.isInstanceOf[Attribute]).isDefined
+    }
+
+    // Given a correlated predicate, check if it is either a non-equality predicate or
+    // an equality predicate that does not guarantee one-to-one mapping between inner and
+    // outer attributes. When the correlated predicate does not contain any attribute
+    // (i.e. only has outer references), it is supported and should return false. E.g.:
+    //   (a = outer(c)) -> false
+    //   (outer(c) = outer(d)) -> false
+    //   (a > outer(c)) -> true
+    //   (a + b = outer(c)) -> true
+    // The last one is true because there can be multiple combinations of (a, b) that
+    // satisfy the equality condition. For example, if outer(c) = 0, then both (0, 0)
+    // and (-1, 1) can make the predicate evaluate to true.
+    def isUnsupportedPredicate(condition: Expression): Boolean = condition match {
+      // Only allow equality condition with one side being an attribute and another
+      // side being an expression without attributes from the inner query. Note
+      // OuterReference is a leaf node and will not be found here.
+      case Equality(_: Attribute, b) => containsAttribute(b)
+      case Equality(a, _: Attribute) => containsAttribute(a)
+      case e @ Equality(_, _) => containsAttribute(e)
+      case _ => true
+    }
+
+    val unsupportedPredicates = mutable.ArrayBuffer.empty[Expression]
 
     // Simplify the predicates before validating any unsupported correlation patterns in the plan.
     AnalysisHelper.allowInvokingTransformsInAnalyzer { BooleanSimplification(sub).foreachUp {
@@ -905,22 +963,17 @@ trait CheckAnalysis extends PredicateHelper {
       // The other operator is Join. Filter can be anywhere in a correlated subquery.
       case f: Filter =>
         val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter)
-
-        // Find any non-equality correlated predicates
-        foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists {
-          case _: EqualTo | _: EqualNullSafe => false
-          case _ => true
-        }
+        unsupportedPredicates ++= correlated.filter(isUnsupportedPredicate)
         failOnInvalidOuterReference(f)
 
       // Aggregate cannot host any correlated expressions
       // It can be on a correlation path if the correlation contains
-      // only equality correlated predicates.
+      // only supported correlated equality predicates.
       // It cannot be on a correlation path if the correlation has
       // non-equality correlated predicates.
       case a: Aggregate =>
         failOnInvalidOuterReference(a)
-        failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a)
+        failOnUnsupportedCorrelatedPredicate(unsupportedPredicates.toSeq, a)
 
       // Join can host correlated expressions.
       case j @ Join(left, right, joinType, _, _) =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 09e0d9c..afc9780 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -698,4 +698,28 @@ class AnalysisErrorSuite extends AnalysisTest {
           UnresolvedRelation(TableIdentifier("t", Option("nonexist")))))))
     assertAnalysisError(plan, "Table or view not found:" :: Nil)
   }
+
+  test("SPARK-35080: Unsupported correlated equality predicates in subquery") {
+    val a = AttributeReference("a", IntegerType)()
+    val b = AttributeReference("b", IntegerType)()
+    val c = AttributeReference("c", IntegerType)()
+    val t1 = LocalRelation(a, b)
+    val t2 = LocalRelation(c)
+    val conditions = Seq(
+      (abs($"a") === $"c", "abs(`a`) = outer(`c`)"),
+      (abs($"a") <=> $"c", "abs(`a`) <=> outer(`c`)"),
+      ($"a" + 1 === $"c", "(`a` + 1) = outer(`c`)"),
+      ($"a" + $"b" === $"c", "(`a` + `b`) = outer(`c`)"),
+      ($"a" + $"c" === $"b", "(`a` + outer(`c`)) = `b`"),
+      (And($"a" === $"c", Cast($"a", IntegerType) === $"c"), "CAST(`a` AS INT) = outer(`c`)"))
+    conditions.foreach { case (cond, msg) =>
+      val plan = Project(
+        ScalarSubquery(
+          Aggregate(Nil, count(Literal(1)).as("cnt") :: Nil,
+            Filter(cond, t1))
+        ).as("sub") :: Nil,
+        t2)
+      assertAnalysisError(plan, s"Correlated column is not allowed in predicate ($msg)" :: Nil)
+    }
+  }
 }
diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out
index 054ee00..c8d7530 100644
--- a/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out
@@ -100,6 +100,14 @@ WHERE  udf(t1.v) >= (SELECT   min(udf(t2.v))
                 FROM     t2
                 WHERE    t2.k = t1.k)
 -- !query schema
-struct<k:string>
+struct<>
 -- !query output
-two
+org.apache.spark.sql.AnalysisException
+Correlated column is not allowed in predicate (CAST(udf(cast(k as string)) AS STRING) = CAST(udf(cast(outer(k#x) as string)) AS STRING)):
+Aggregate [cast(udf(cast(max(cast(udf(cast(v#x as string)) as int)) as string)) as int) AS CAST(udf(cast(max(cast(udf(cast(v as string)) as int)) as string)) AS INT)#x]
++- Filter (cast(udf(cast(k#x as string)) as string) = cast(udf(cast(outer(k#x) as string)) as string))
+   +- SubqueryAlias t2
+      +- Project [k#x, v#x]
+         +- SubqueryAlias t2
+            +- LocalRelation [k#x, v#x]
+;
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 141f127..e369bc9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -541,7 +541,7 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1")
     }
     assert(msg1.getMessage.contains(
-      "Correlated column is not allowed in a non-equality predicate:"))
+      "Correlated column is not allowed in predicate (l2.`a` < outer(l1.`a`))"))
   }
 
   test("disjunctive correlated scalar subquery") {
@@ -1646,4 +1646,13 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
     checkAnswer(df, df2)
     checkAnswer(df, Nil)
   }
+
+  test("SPARK-35080: correlated equality predicates contain only outer references") {
+    withTempView("t") {
+      Seq((0, 1), (1, 1)).toDF("c1", "c2").createOrReplaceTempView("t")
+      checkAnswer(
+        sql("select c1, c2, (select count(*) from l where c1 = c2) from t"),
+        Row(0, 1, 0) :: Row(1, 1, 8) :: Nil)
+    }
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org