Posted to commits@spark.apache.org by we...@apache.org on 2022/03/01 07:22:58 UTC

[spark] branch branch-3.0 updated: [SPARK-38180][SQL][3.1] Allow safe up-cast expressions in correlated equality predicates

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 5b3f567  [SPARK-38180][SQL][3.1] Allow safe up-cast expressions in correlated equality predicates
5b3f567 is described below

commit 5b3f5672a679207b0089a5e421e47b16a4e6b4d3
Author: allisonwang-db <al...@databricks.com>
AuthorDate: Tue Mar 1 15:13:50 2022 +0800

    [SPARK-38180][SQL][3.1] Allow safe up-cast expressions in correlated equality predicates
    
    Backport https://github.com/apache/spark/pull/35486 to branch-3.1.
    
    ### What changes were proposed in this pull request?
    
    This PR relaxes the constraint added in [SPARK-35080](https://issues.apache.org/jira/browse/SPARK-35080) by allowing safe up-cast expressions in correlated equality predicates.
    
    ### Why are the changes needed?
    
    Cast expressions are often added by the compiler during query analysis. The check on correlated equality predicates can be relaxed to support this common pattern when the cast guarantees a one-to-one mapping between the child expression's values and the output data type (a safe up-cast).
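
    As a rough illustration (not part of this patch), the one-to-one property lines up with what the existing `Cast.canUpCast` helper reports; a minimal sketch, assuming spark-catalyst is on the classpath:
    ```scala
    // Sketch only: Cast.canUpCast reports whether a cast is loss-free, i.e. maps
    // each input value to a distinct output value -- the one-to-one property
    // required for a correlated equality predicate.
    import org.apache.spark.sql.catalyst.expressions.Cast
    import org.apache.spark.sql.types.{IntegerType, LongType, ShortType, StringType}

    Cast.canUpCast(IntegerType, LongType)   // true:  widening, no values collide
    Cast.canUpCast(IntegerType, StringType) // true:  each INT has a distinct string form
    Cast.canUpCast(IntegerType, ShortType)  // false: truncation, distinct INTs can collide
    ```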
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Safe up-cast expressions are allowed in correlated equality predicates:
    ```sql
    SELECT (SELECT SUM(b) FROM VALUES (1, 1), (1, 2) t(a, b) WHERE CAST(a AS STRING) = x)
    FROM VALUES ('1'), ('2') t(x)
    ```
    Before this change, this query would throw an AnalysisException ("Correlated column is not allowed in predicate..."); after this change, it runs successfully.
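
    For contrast, a cast that is not a safe up-cast is still rejected. A hedged sketch of the same query shape, assuming a spark-shell session with `spark` in scope (this mirrors the new SubquerySuite test):
    ```scala
    // Sketch only: INT -> SHORT is not a safe up-cast (distinct inner values can
    // collide after truncation), so the correlated predicate is still disallowed.
    spark.sql(
      """SELECT (SELECT SUM(b) FROM VALUES (1, 1), (1, 2) t(a, b)
        |        WHERE CAST(a AS SHORT) = x)
        |FROM VALUES (CAST(1 AS SHORT)), (CAST(2 AS SHORT)) t(x)""".stripMargin)
    // expected to fail analysis with: "Correlated column is not allowed in predicate ..."
    ```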
    
    ### How was this patch tested?
    
    Unit tests.
    
    Closes #35689 from allisonwang-db/spark-38180-3.1.
    
    Authored-by: allisonwang-db <al...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 1e1f6b2aac5091343d572fb2472f46fa574882eb)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../sql/catalyst/analysis/CheckAnalysis.scala      | 25 ++++++++++++++-------
 .../sql/catalyst/analysis/AnalysisErrorSuite.scala |  5 +++--
 .../scala/org/apache/spark/sql/SubquerySuite.scala | 26 ++++++++++++++++++++++
 3 files changed, 46 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index fe12dd4..a35650f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -913,14 +913,23 @@ trait CheckAnalysis extends PredicateHelper {
     // The last one is true because there can be multiple combinations of (a, b) that
     // satisfy the equality condition. For example, if outer(c) = 0, then both (0, 0)
     // and (-1, 1) can make the predicate evaluate to true.
-    def isUnsupportedPredicate(condition: Expression): Boolean = condition match {
-      // Only allow equality condition with one side being an attribute and another
-      // side being an expression without attributes from the inner query. Note
-      // OuterReference is a leaf node and will not be found here.
-      case Equality(_: Attribute, b) => containsAttribute(b)
-      case Equality(a, _: Attribute) => containsAttribute(a)
-      case e @ Equality(_, _) => containsAttribute(e)
-      case _ => true
+    def isUnsupportedPredicate(condition: Expression): Boolean = {
+      def isSupported(e: Expression): Boolean = e match {
+        case _: Attribute => true
+        // SPARK-38180: Allow Cast expressions that guarantee 1:1 mapping.
+        case Cast(a: Attribute, dataType, _) => Cast.canUpCast(a.dataType, dataType)
+        case _ => false
+      }
+
+      condition match {
+        // Only allow equality condition with one side being an attribute and another
+        // side being an expression without attributes from the inner query. Note
+        // OuterReference is a leaf node and will not be found here.
+        case Equality(a, b) if isSupported(a) => containsAttribute(b)
+        case Equality(a, b) if isSupported(b) => containsAttribute(a)
+        case e @ Equality(_, _) => containsAttribute(e)
+        case _ => true
+      }
     }
 
     val unsupportedPredicates = mutable.ArrayBuffer.empty[Expression]
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 348c282..b549e77 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -703,7 +703,8 @@ class AnalysisErrorSuite extends AnalysisTest {
     val a = AttributeReference("a", IntegerType)()
     val b = AttributeReference("b", IntegerType)()
     val c = AttributeReference("c", IntegerType)()
-    val t1 = LocalRelation(a, b)
+    val d = AttributeReference("d", DoubleType)()
+    val t1 = LocalRelation(a, b, d)
     val t2 = LocalRelation(c)
     val conditions = Seq(
       (abs($"a") === $"c", "abs(`a`) = outer(`c`)"),
@@ -711,7 +712,7 @@ class AnalysisErrorSuite extends AnalysisTest {
       ($"a" + 1 === $"c", "(`a` + 1) = outer(`c`)"),
       ($"a" + $"b" === $"c", "(`a` + `b`) = outer(`c`)"),
       ($"a" + $"c" === $"b", "(`a` + outer(`c`)) = `b`"),
-      (And($"a" === $"c", Cast($"a", IntegerType) === $"c"), "CAST(`a` AS INT) = outer(`c`)"))
+      (And($"a" === $"c", Cast($"d", IntegerType) === $"c"), "CAST(`d` AS INT) = outer(`c`)"))
     conditions.foreach { case (cond, msg) =>
       val plan = Project(
         ScalarSubquery(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index e369bc9..df638c9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -1655,4 +1655,30 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
         Row(0, 1, 0) :: Row(1, 1, 8) :: Nil)
     }
   }
+
+  test("SPARK-38180: allow safe cast expressions in correlated equality conditions") {
+    withTempView("t1", "t2") {
+      Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t1")
+      Seq((0, 2), (0, 3)).toDF("c1", "c2").createOrReplaceTempView("t2")
+      checkAnswer(sql(
+        """
+          |SELECT (SELECT SUM(c2) FROM t2 WHERE c1 = a)
+          |FROM (SELECT CAST(c1 AS DOUBLE) a FROM t1)
+          |""".stripMargin),
+        Row(5) :: Row(null) :: Nil)
+      checkAnswer(sql(
+        """
+          |SELECT (SELECT SUM(c2) FROM t2 WHERE CAST(c1 AS STRING) = a)
+          |FROM (SELECT CAST(c1 AS STRING) a FROM t1)
+          |""".stripMargin),
+        Row(5) :: Row(null) :: Nil)
+      assert(intercept[AnalysisException] {
+        sql(
+          """
+            |SELECT (SELECT SUM(c2) FROM t2 WHERE CAST(c1 AS SHORT) = a)
+            |FROM (SELECT CAST(c1 AS SHORT) a FROM t1)
+            |""".stripMargin)
+      }.getMessage.contains("Correlated column is not allowed in predicate"))
+    }
+  }
 }
