You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2021/04/20 03:27:04 UTC

[spark] branch branch-3.1 updated: [SPARK-35080][SQL] Only allow a subset of correlated equality predicates when a subquery is aggregated

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 034ba76  [SPARK-35080][SQL] Only allow a subset of correlated equality predicates when a subquery is aggregated
034ba76 is described below

commit 034ba76a69101e6fb55d7dfaf48e3610415cf8a1
Author: allisonwang-db <66...@users.noreply.github.com>
AuthorDate: Tue Apr 20 11:11:40 2021 +0800

    [SPARK-35080][SQL] Only allow a subset of correlated equality predicates when a subquery is aggregated
    
    This PR updated the `foundNonEqualCorrelatedPred` logic for correlated subqueries in `CheckAnalysis` to only allow correlated equality predicates that guarantee one-to-one mapping between inner and outer attributes, instead of all equality predicates.
    
    This change fixes correctness bugs. Before this fix, Spark could return wrong results for certain correlated subqueries that pass CheckAnalysis:
    Example 1:
    ```sql
    create or replace view t1(c) as values ('a'), ('b')
    create or replace view t2(c) as values ('ab'), ('abc'), ('bc')
    
    select c, (select count(*) from t2 where t1.c = substring(t2.c, 1, 1)) from t1
    ```
    Correct results: [(a, 2), (b, 1)]
    Spark results:
    ```
    +---+-----------------+
    |c  |scalarsubquery(c)|
    +---+-----------------+
    |a  |1                |
    |a  |1                |
    |b  |1                |
    +---+-----------------+
    ```
    Example 2:
    ```sql
    create or replace view t1(a, b) as values (0, 6), (1, 5), (2, 4), (3, 3);
    create or replace view t2(c) as values (6);
    
    select c, (select count(*) from t1 where a + b = c) from t2;
    ```
    Correct results: [(6, 4)]
    Spark results:
    ```
    +---+-----------------+
    |c  |scalarsubquery(c)|
    +---+-----------------+
    |6  |1                |
    |6  |1                |
    |6  |1                |
    |6  |1                |
    +---+-----------------+
    ```
    Yes, this is a user-facing change: users will no longer be able to run queries that contain unsupported correlated equality predicates.
    
    Added unit tests.
    
    Closes #32179 from allisonwang-db/spark-35080-subquery-bug.
    
    Lead-authored-by: allisonwang-db <66...@users.noreply.github.com>
    Co-authored-by: Wenchen Fan <cl...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit bad4b6f025de4946112a0897892a97d5ae6822cf)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../sql/catalyst/analysis/CheckAnalysis.scala      | 77 ++++++++++++++++++----
 .../sql/catalyst/analysis/AnalysisErrorSuite.scala | 24 +++++++
 .../sql-tests/results/udf/udf-except.sql.out       | 12 +++-
 .../scala/org/apache/spark/sql/SubquerySuite.scala | 11 +++-
 4 files changed, 109 insertions(+), 15 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 3e084f0..3dfe7f4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -891,14 +891,72 @@ trait CheckAnalysis extends PredicateHelper {
     // +- SubqueryAlias t1, `t1`
     // +- Project [_1#73 AS c1#76, _2#74 AS c2#77]
     // +- LocalRelation [_1#73, _2#74]
-    def failOnNonEqualCorrelatedPredicate(found: Boolean, p: LogicalPlan): Unit = {
-      if (found) {
+    // SPARK-35080: The same issue can happen to correlated equality predicates when
+    // they do not guarantee one-to-one mapping between inner and outer attributes.
+    // For example:
+    // Table:
+    //   t1(a, b): [(0, 6), (1, 5), (2, 4)]
+    //   t2(c): [(6)]
+    //
+    // Query:
+    //   SELECT c, (SELECT COUNT(*) FROM t1 WHERE a + b = c) FROM t2
+    //
+    // Original subquery plan:
+    //   Aggregate [count(1)]
+    //   +- Filter ((a + b) = outer(c))
+    //      +- LocalRelation [a, b]
+    //
+    // Plan after pulling up correlated predicates:
+    //   Aggregate [a, b] [count(1), a, b]
+    //   +- LocalRelation [a, b]
+    //
+    // Plan after rewrite:
+    //   Project [c, count(1)]
+    //   +- Join LeftOuter ((a + b) = c)
+    //      :- LocalRelation [c]
+    //      +- Aggregate [a, b] [count(1), a, b]
+    //         +- LocalRelation [a, b]
+    //
+    // The right hand side of the join transformed from the subquery will output
+    //   count(1) | a | b
+    //      1     | 0 | 6
+    //      1     | 1 | 5
+    //      1     | 2 | 4
+    // so the rewritten plan returns three rows with count 1 instead of one row with count 3.
+    def failOnUnsupportedCorrelatedPredicate(predicates: Seq[Expression], p: LogicalPlan): Unit = {
+      if (predicates.nonEmpty) {
         // Report a non-supported case as an exception
-        failAnalysis(s"Correlated column is not allowed in a non-equality predicate:\n$p")
+        failAnalysis("Correlated column is not allowed in predicate " +
+          s"${predicates.map(_.sql).mkString}:\n$p")
       }
     }
 
-    var foundNonEqualCorrelatedPred: Boolean = false
+    def containsAttribute(e: Expression): Boolean = {
+      e.find(_.isInstanceOf[Attribute]).isDefined
+    }
+
+    // Given a correlated predicate, check if it is either a non-equality predicate or
+    // an equality predicate that does not guarantee one-to-one mapping between inner and
+    // outer attributes. When the correlated predicate does not contain any attribute
+    // (i.e. only has outer references), it is supported and should return false. E.g.:
+    //   (a = outer(c)) -> false
+    //   (outer(c) = outer(d)) -> false
+    //   (a > outer(c)) -> true
+    //   (a + b = outer(c)) -> true
+    // The last one is true because there can be multiple combinations of (a, b) that
+    // satisfy the equality condition. For example, if outer(c) = 0, then both (0, 0)
+    // and (-1, 1) can make the predicate evaluate to true.
+    def isUnsupportedPredicate(condition: Expression): Boolean = condition match {
+      // Only allow equality condition with one side being an attribute and another
+      // side being an expression without attributes from the inner query. Note
+      // OuterReference is a leaf node and will not be found here.
+      case Equality(_: Attribute, b) => containsAttribute(b)
+      case Equality(a, _: Attribute) => containsAttribute(a)
+      case e @ Equality(_, _) => containsAttribute(e)
+      case _ => true
+    }
+
+    val unsupportedPredicates = mutable.ArrayBuffer.empty[Expression]
 
     // Simplify the predicates before validating any unsupported correlation patterns in the plan.
     AnalysisHelper.allowInvokingTransformsInAnalyzer { BooleanSimplification(sub).foreachUp {
@@ -941,22 +999,17 @@ trait CheckAnalysis extends PredicateHelper {
       // The other operator is Join. Filter can be anywhere in a correlated subquery.
       case f: Filter =>
         val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter)
-
-        // Find any non-equality correlated predicates
-        foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists {
-          case _: EqualTo | _: EqualNullSafe => false
-          case _ => true
-        }
+        unsupportedPredicates ++= correlated.filter(isUnsupportedPredicate)
         failOnInvalidOuterReference(f)
 
       // Aggregate cannot host any correlated expressions
       // It can be on a correlation path if the correlation contains
-      // only equality correlated predicates.
+      // only supported correlated equality predicates.
       // It cannot be on a correlation path if the correlation has
       // non-equality correlated predicates.
       case a: Aggregate =>
         failOnInvalidOuterReference(a)
-        failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a)
+        failOnUnsupportedCorrelatedPredicate(unsupportedPredicates.toSeq, a)
 
       // Join can host correlated expressions.
       case j @ Join(left, right, joinType, _, _) =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 44128c4..20ba9c5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -700,4 +700,28 @@ class AnalysisErrorSuite extends AnalysisTest {
           UnresolvedRelation(TableIdentifier("t", Option("nonexist")))))))
     assertAnalysisError(plan, "Table or view not found:" :: Nil)
   }
+
+  test("SPARK-35080: Unsupported correlated equality predicates in subquery") {
+    val a = AttributeReference("a", IntegerType)()
+    val b = AttributeReference("b", IntegerType)()
+    val c = AttributeReference("c", IntegerType)()
+    val t1 = LocalRelation(a, b)
+    val t2 = LocalRelation(c)
+    val conditions = Seq(
+      (abs($"a") === $"c", "abs(`a`) = outer(`c`)"),
+      (abs($"a") <=> $"c", "abs(`a`) <=> outer(`c`)"),
+      ($"a" + 1 === $"c", "(`a` + 1) = outer(`c`)"),
+      ($"a" + $"b" === $"c", "(`a` + `b`) = outer(`c`)"),
+      ($"a" + $"c" === $"b", "(`a` + outer(`c`)) = `b`"),
+      (And($"a" === $"c", Cast($"a", IntegerType) === $"c"), "CAST(`a` AS INT) = outer(`c`)"))
+    conditions.foreach { case (cond, msg) =>
+      val plan = Project(
+        ScalarSubquery(
+          Aggregate(Nil, count(Literal(1)).as("cnt") :: Nil,
+            Filter(cond, t1))
+        ).as("sub") :: Nil,
+        t2)
+      assertAnalysisError(plan, s"Correlated column is not allowed in predicate ($msg)" :: Nil)
+    }
+  }
 }
diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out
index 054ee00..43506b4 100644
--- a/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out
@@ -100,6 +100,14 @@ WHERE  udf(t1.v) >= (SELECT   min(udf(t2.v))
                 FROM     t2
                 WHERE    t2.k = t1.k)
 -- !query schema
-struct<k:string>
+struct<>
 -- !query output
-two
+org.apache.spark.sql.AnalysisException
+Correlated column is not allowed in predicate (CAST(udf(cast(k as string)) AS STRING) = CAST(udf(cast(outer(k#x) as string)) AS STRING)):
+Aggregate [cast(udf(cast(max(cast(udf(cast(v#x as string)) as int)) as string)) as int) AS CAST(udf(cast(max(cast(udf(cast(v as string)) as int)) as string)) AS INT)#x]
++- Filter (cast(udf(cast(k#x as string)) as string) = cast(udf(cast(outer(k#x) as string)) as string))
+   +- SubqueryAlias t2
+      +- View (`t2`, [k#x,v#x])
+         +- Project [k#x, v#x]
+            +- SubqueryAlias t2
+               +- LocalRelation [k#x, v#x]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 73b2349..fafe1bb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -542,7 +542,7 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1")
     }
     assert(msg1.getMessage.contains(
-      "Correlated column is not allowed in a non-equality predicate:"))
+      "Correlated column is not allowed in predicate (l2.`a` < outer(l1.`a`))"))
   }
 
   test("disjunctive correlated scalar subquery") {
@@ -1753,4 +1753,13 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       }
     }
   }
+
+  test("SPARK-35080: correlated equality predicates contain only outer references") {
+    withTempView("t") {
+      Seq((0, 1), (1, 1)).toDF("c1", "c2").createOrReplaceTempView("t")
+      checkAnswer(
+        sql("select c1, c2, (select count(*) from l where c1 = c2) from t"),
+        Row(0, 1, 0) :: Row(1, 1, 8) :: Nil)
+    }
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org