Posted to commits@spark.apache.org by we...@apache.org on 2019/07/27 02:39:09 UTC

[spark] branch master updated: [SPARK-28441][SQL][PYTHON] Fix error when non-foldable expression is used in correlated scalar subquery

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 558dd23  [SPARK-28441][SQL][PYTHON] Fix error when non-foldable expression is used in correlated scalar subquery
558dd23 is described below

commit 558dd2360163250e9fb55c3a49f87c907b65ea0d
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Sat Jul 27 10:38:34 2019 +0800

    [SPARK-28441][SQL][PYTHON] Fix error when non-foldable expression is used in correlated scalar subquery
    
    ## What changes were proposed in this pull request?
    
    In SPARK-15370, we checked the expression at the root of the correlated subquery in order to fix the COUNT bug. If a `PythonUDF` is in the checking path, evaluating it causes a failure because we can't statically evaluate `PythonUDF`. The Python UDF test added in SPARK-28277 shows this issue.
    
    As before, if we can statically evaluate the expression, we intercept NULL values coming from the outer join and replace them with the literal value the subquery's expression produces on zero input tuples. If we can't, we replace them with the `PythonUDF` expression itself, with its parameters statically evaluated.
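    
    The core of the change is a small helper in `subquery.scala` (shown in full in the diff below) that folds an expression to a literal only when it is actually foldable:
    
    ```scala
    private def tryEvalExpr(expr: Expression): Expression = {
      // Remove Alias over the given expression, because Alias is not foldable.
      if (!CleanupAliases.trimAliases(expr).foldable) {
        // SPARK-28441: some expressions, like PythonUDF, can't be statically
        // evaluated, so defer their evaluation to query runtime.
        expr
      } else {
        Literal.create(expr.eval(), expr.dataType)
      }
    }
    ```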
    
    After this change, the last query in `udf-except.sql`, which previously threw `java.lang.UnsupportedOperationException`, can be run:
    
    ```
    SELECT t1.k
    FROM   t1
    WHERE  t1.v <= (SELECT   udf(max(udf(t2.v)))
                    FROM     t2
                    WHERE    udf(t2.k) = udf(t1.k))
    MINUS
    SELECT t1.k
    FROM   t1
    WHERE  udf(t1.v) >= (SELECT   min(udf(t2.v))
                    FROM     t2
                    WHERE    t2.k = t1.k)
    -- !query 2 schema
    struct<k:string>
    -- !query 2 output
    two
    ```
    
    Note that this issue also applies to other non-foldable expressions, such as `rand`. As with `PythonUDF`, we can't call `eval` on such expressions during optimization; their evaluation needs to be deferred to query runtime.
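    
    For example, one of the added `SubquerySuite` tests (excerpted from the diff below) exercises this with `rand()`; the non-foldable `count(*) + cast(rand() as int)` is kept as an expression and evaluated at runtime:
    
    ```scala
    // Case 1: canonical example of the COUNT bug, with a non-foldable expression.
    checkAnswer(
      sql("SELECT l.a FROM l WHERE (SELECT count(*) + cast(rand() as int) FROM r " +
        "WHERE l.a = r.c) < l.a"),
      Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil)
    ```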
    
    ## How was this patch tested?
    
    Added tests.
    
    Closes #25204 from viirya/SPARK-28441.
    
    Authored-by: Liang-Chi Hsieh <vi...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |   2 +-
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |   4 +-
 .../spark/sql/catalyst/optimizer/subquery.scala    |  82 +++++---
 .../scala/org/apache/spark/sql/SubquerySuite.scala | 227 +++++++++++++++++++++
 4 files changed, 290 insertions(+), 25 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 021fb26..5bf4dc1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2833,7 +2833,7 @@ object EliminateUnions extends Rule[LogicalPlan] {
  * rule can't work for those parameters.
  */
 object CleanupAliases extends Rule[LogicalPlan] {
-  private def trimAliases(e: Expression): Expression = {
+  def trimAliases(e: Expression): Expression = {
     e.transformDown {
       case Alias(child, _) => child
       case MultiAlias(child, _) => child
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index af90ef4..1c36cdc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -675,7 +675,9 @@ object ColumnPruning extends Rule[LogicalPlan] {
    */
   private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transformUp {
     case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child)))
-      if p2.outputSet.subsetOf(child.outputSet) =>
+      if p2.outputSet.subsetOf(child.outputSet) &&
+        // We only remove attribute-only project.
+        p2.projectList.forall(_.isInstanceOf[AttributeReference]) =>
       p1.copy(child = f.copy(child = child))
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index e78ed1c..4f7333c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.CleanupAliases
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -317,24 +318,40 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
   }
 
   /**
+   * Checks if given expression is foldable. Evaluates it and returns it as literal, if yes.
+   * If not, returns the original expression without evaluation.
+   */
+  private def tryEvalExpr(expr: Expression): Expression = {
+    // Removes Alias over given expression, because Alias is not foldable.
+    if (!CleanupAliases.trimAliases(expr).foldable) {
+      // SPARK-28441: Some expressions, like PythonUDF, can't be statically evaluated.
+      // Needs to evaluate them on query runtime.
+      expr
+    } else {
+      Literal.create(expr.eval(), expr.dataType)
+    }
+  }
+
+  /**
    * Statically evaluate an expression containing zero or more placeholders, given a set
-   * of bindings for placeholder values.
+   * of bindings for placeholder values, if the expression is evaluable. If it is not,
+   * bind statically evaluated expression results to an expression.
    */
-  private def evalExpr(expr: Expression, bindings: Map[ExprId, Option[Any]]) : Option[Any] = {
+  private def bindingExpr(
+      expr: Expression,
+      bindings: Map[ExprId, Expression]): Expression = {
     val rewrittenExpr = expr transform {
       case r: AttributeReference =>
-        bindings(r.exprId) match {
-          case Some(v) => Literal.create(v, r.dataType)
-          case None => Literal.default(NullType)
-        }
+        bindings.getOrElse(r.exprId, Literal.default(NullType))
     }
-    Option(rewrittenExpr.eval())
+
+    tryEvalExpr(rewrittenExpr)
   }
 
   /**
    * Statically evaluate an expression containing one or more aggregates on an empty input.
    */
-  private def evalAggOnZeroTups(expr: Expression) : Option[Any] = {
+  private def evalAggOnZeroTups(expr: Expression) : Expression = {
     // AggregateExpressions are Unevaluable, so we need to replace all aggregates
     // in the expression with the value they would return for zero input tuples.
     // Also replace attribute refs (for example, for grouping columns) with NULL.
@@ -344,7 +361,8 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
 
       case _: AttributeReference => Literal.default(NullType)
     }
-    Option(rewrittenExpr.eval())
+
+    tryEvalExpr(rewrittenExpr)
   }
 
   /**
@@ -354,19 +372,33 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
    * [[org.apache.spark.sql.catalyst.analysis.CheckAnalysis]]. If the checks in
    * CheckAnalysis become less restrictive, this method will need to change.
    */
-  private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Any] = {
+  private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Expression] = {
     // Inputs to this method will start with a chain of zero or more SubqueryAlias
     // and Project operators, followed by an optional Filter, followed by an
     // Aggregate. Traverse the operators recursively.
-    def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = lp match {
+    def evalPlan(lp : LogicalPlan) : Map[ExprId, Expression] = lp match {
       case SubqueryAlias(_, child) => evalPlan(child)
       case Filter(condition, child) =>
         val bindings = evalPlan(child)
-        if (bindings.isEmpty) bindings
-        else {
-          val exprResult = evalExpr(condition, bindings).getOrElse(false)
-            .asInstanceOf[Boolean]
-          if (exprResult) bindings else Map.empty
+        if (bindings.isEmpty) {
+          bindings
+        } else {
+          val bindCondition = bindingExpr(condition, bindings)
+
+          if (!bindCondition.foldable) {
+            // We can't evaluate the condition. Evaluate it in query runtime.
+            bindings.map { case (id, expr) =>
+              val newExpr = If(bindCondition, expr, Literal.create(null, expr.dataType))
+              (id, newExpr)
+            }
+          } else {
+            // The bound condition can be evaluated.
+            bindCondition.eval() match {
+              // For filter condition, null is the same as false.
+              case null | false => Map.empty
+              case true => bindings
+            }
+          }
         }
 
       case Project(projectList, child) =>
@@ -374,7 +406,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
         if (bindings.isEmpty) {
           bindings
         } else {
-          projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap
+          projectList.map(ne => (ne.exprId, bindingExpr(ne, bindings))).toMap
         }
 
       case Aggregate(_, aggExprs, _) =>
@@ -382,8 +414,9 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
         // for joining with the outer query block. Fill those expressions in with
         // nulls and statically evaluate the remainder.
         aggExprs.map {
-          case ref: AttributeReference => (ref.exprId, None)
-          case alias @ Alias(_: AttributeReference, _) => (alias.exprId, None)
+          case ref: AttributeReference => (ref.exprId, Literal.create(null, ref.dataType))
+          case alias @ Alias(_: AttributeReference, _) =>
+            (alias.exprId, Literal.create(null, alias.dataType))
           case ne => (ne.exprId, evalAggOnZeroTups(ne))
         }.toMap
 
@@ -394,7 +427,10 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
     val resultMap = evalPlan(plan)
 
     // By convention, the scalar subquery result is the leftmost field.
-    resultMap.getOrElse(plan.output.head.exprId, None)
+    resultMap.get(plan.output.head.exprId) match {
+      case Some(Literal(null, _)) | None => None
+      case o => o
+    }
   }
 
   /**
@@ -473,7 +509,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
               currentChild.output :+
                 Alias(
                   If(IsNull(alwaysTrueRef),
-                    Literal.create(resultWithZeroTups.get, origOutput.dataType),
+                    resultWithZeroTups.get,
                     aggValRef), origOutput.name)(exprId = origOutput.exprId),
               Join(currentChild,
                 Project(query.output :+ alwaysTrueExpr, query),
@@ -494,11 +530,11 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
               case op => sys.error(s"Unexpected operator $op in corelated subquery")
             }
 
-            // CASE WHEN alwayTrue IS NULL THEN resultOnZeroTups
+            // CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups
             //      WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
             //      ELSE (aggregate value) END AS (original column name)
             val caseExpr = Alias(CaseWhen(Seq(
-              (IsNull(alwaysTrueRef), Literal.create(resultWithZeroTups.get, origOutput.dataType)),
+              (IsNull(alwaysTrueRef), resultWithZeroTups.get),
               (Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
               aggValRef),
               origOutput.name)(exprId = origOutput.exprId)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index b2c3868..4ec85b0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -1384,4 +1384,231 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
     assert(subqueryExecs.forall(_.name.startsWith("scalar-subquery#")),
           "SubqueryExec name should start with scalar-subquery#")
   }
+
+  test("SPARK-28441: COUNT bug in WHERE clause (Filter) with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    // Case 1: Canonical example of the COUNT bug
+    checkAnswer(
+      sql("SELECT l.a FROM l WHERE (SELECT udf(count(*)) FROM r WHERE l.a = r.c) < l.a"),
+      Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil)
+    // Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently uses
+    // a rewrite that is vulnerable to the COUNT bug
+    checkAnswer(
+      sql("SELECT l.a FROM l WHERE (SELECT udf(count(*)) FROM r WHERE l.a = r.c) = 0"),
+      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
+    // Case 3: COUNT bug without a COUNT aggregate
+    checkAnswer(
+      sql("SELECT l.a FROM l WHERE (SELECT udf(sum(r.d)) is null FROM r WHERE l.a = r.c)"),
+      Row(1) :: Row(1) ::Row(null) :: Row(null) :: Row(6) :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug in SELECT clause (Project) with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    checkAnswer(
+      sql("SELECT a, (SELECT udf(count(*)) FROM r WHERE l.a = r.c) AS cnt FROM l"),
+      Row(1, 0) :: Row(1, 0) :: Row(2, 2) :: Row(2, 2) :: Row(3, 1) :: Row(null, 0)
+        :: Row(null, 0) :: Row(6, 1) :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug in HAVING clause (Filter) with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    checkAnswer(
+      sql("""
+            |SELECT
+            |  l.a AS grp_a
+            |FROM l GROUP BY l.a
+            |HAVING
+            |  (
+            |    SELECT udf(count(*)) FROM r WHERE grp_a = r.c
+            |  ) = 0
+            |ORDER BY grp_a""".stripMargin),
+      Row(null) :: Row(1) :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug in Aggregate with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    checkAnswer(
+      sql("""
+            |SELECT
+            |  l.a AS aval,
+            |  sum(
+            |    (
+            |      SELECT udf(count(*)) FROM r WHERE l.a = r.c
+            |    )
+            |  ) AS cnt
+            |FROM l GROUP BY l.a ORDER BY aval""".stripMargin),
+      Row(null, 0) :: Row(1, 0) :: Row(2, 4) :: Row(3, 1) :: Row(6, 1)  :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug negative examples with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    // Case 1: Potential COUNT bug case that was working correctly prior to the fix
+    checkAnswer(
+      sql("SELECT l.a FROM l WHERE (SELECT udf(sum(r.d)) FROM r WHERE l.a = r.c) is null"),
+      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Row(6) :: Nil)
+    // Case 2: COUNT aggregate but no COUNT bug due to > 0 test.
+    checkAnswer(
+      sql("SELECT l.a FROM l WHERE (SELECT udf(count(*)) FROM r WHERE l.a = r.c) > 0"),
+      Row(2) :: Row(2) :: Row(3) :: Row(6) :: Nil)
+    // Case 3: COUNT inside aggregate expression but no COUNT bug.
+    checkAnswer(
+      sql("""
+            |SELECT
+            |  l.a
+            |FROM l
+            |WHERE
+            |  (
+            |    SELECT udf(count(*)) + udf(sum(r.d))
+            |    FROM r WHERE l.a = r.c
+            |  ) = 0""".stripMargin),
+      Nil)
+  }
+
+  test("SPARK-28441: COUNT bug in nested subquery with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    checkAnswer(
+      sql("""
+            |SELECT l.a FROM l
+            |WHERE (
+            |    SELECT cntPlusOne + 1 AS cntPlusTwo FROM (
+            |        SELECT cnt + 1 AS cntPlusOne FROM (
+            |            SELECT udf(sum(r.c)) s, udf(count(*)) cnt FROM r WHERE l.a = r.c
+            |                   HAVING cnt = 0
+            |        )
+            |    )
+            |) = 2""".stripMargin),
+      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug with nasty predicate expr with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    checkAnswer(
+      sql("""
+            |SELECT
+            |  l.a
+            |FROM l WHERE
+            |  (
+            |    SELECT CASE WHEN udf(count(*)) = 1 THEN null ELSE udf(count(*)) END AS cnt
+            |    FROM r WHERE l.a = r.c
+            |  ) = 0""".stripMargin),
+      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug with attribute ref in subquery input and output with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    checkAnswer(
+      sql(
+        """
+          |SELECT
+          |  l.b,
+          |  (
+          |    SELECT (r.c + udf(count(*))) is null
+          |    FROM r
+          |    WHERE l.a = r.c GROUP BY r.c
+          |  )
+          |FROM l
+        """.stripMargin),
+      Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) ::
+        Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug with non-foldable expression") {
+    // Case 1: Canonical example of the COUNT bug
+    checkAnswer(
+      sql("SELECT l.a FROM l WHERE (SELECT count(*) + cast(rand() as int) FROM r " +
+        "WHERE l.a = r.c) < l.a"),
+      Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil)
+    // Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently uses
+    // a rewrite that is vulnerable to the COUNT bug
+    checkAnswer(
+      sql("SELECT l.a FROM l WHERE (SELECT count(*) + cast(rand() as int) FROM r " +
+        "WHERE l.a = r.c) = 0"),
+      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
+    // Case 3: COUNT bug without a COUNT aggregate
+    checkAnswer(
+      sql("SELECT l.a FROM l WHERE (SELECT sum(r.d) is null from r " +
+        "WHERE l.a = r.c)"),
+      Row(1) :: Row(1) ::Row(null) :: Row(null) :: Row(6) :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug in nested subquery with non-foldable expr") {
+    checkAnswer(
+      sql("""
+            |SELECT l.a FROM l
+            |WHERE (
+            |  SELECT cntPlusOne + 1 AS cntPlusTwo FROM (
+            |    SELECT cnt + 1 AS cntPlusOne FROM (
+            |      SELECT sum(r.c) s, (count(*) + cast(rand() as int)) cnt FROM r
+            |        WHERE l.a = r.c HAVING cnt = 0
+            |      )
+            |  )
+            |) = 2""".stripMargin),
+      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug with non-foldable expression in Filter condition") {
+    val df = sql("""
+                   |SELECT
+                   |  l.a
+                   |FROM l WHERE
+                   |  (
+                   |    SELECT cntPlusOne + 1 as cntPlusTwo FROM
+                   |    (
+                   |      SELECT cnt + 1 as cntPlusOne FROM
+                   |      (
+                   |        SELECT sum(r.c) s, count(*) cnt FROM r WHERE l.a = r.c HAVING cnt > 0
+                   |      )
+                   |    )
+                   |  ) = 2""".stripMargin)
+    val df2 = sql("""
+                    |SELECT
+                    |  l.a
+                    |FROM l WHERE
+                    |  (
+                    |    SELECT cntPlusOne + 1 AS cntPlusTwo
+                    |    FROM
+                    |      (
+                    |        SELECT cnt + 1 AS cntPlusOne
+                    |        FROM
+                    |          (
+                    |            SELECT sum(r.c) s, count(*) cnt FROM r
+                    |            WHERE l.a = r.c HAVING (cnt + cast(rand() as int)) > 0
+                    |          )
+                    |       )
+                    |   ) = 2""".stripMargin)
+    checkAnswer(df, df2)
+    checkAnswer(df, Nil)
+  }
 }

