You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/08/24 13:47:15 UTC

spark git commit: [SPARK-21759][SQL] In.checkInputDataTypes should not wrongly report unresolved plans for IN correlated subquery

Repository: spark
Updated Branches:
  refs/heads/master 9e33954dd -> 183d4cb71


[SPARK-21759][SQL] In.checkInputDataTypes should not wrongly report unresolved plans for IN correlated subquery

## What changes were proposed in this pull request?

With the check for structural integrity proposed in SPARK-21726, it is found that the optimization rule `PullupCorrelatedPredicates` can produce unresolved plans.

For a correlated IN query that looks like:

    SELECT t1.a FROM t1
    WHERE
    t1.a IN (SELECT t2.c
            FROM t2
            WHERE t1.b < t2.d);

The query plan might look like:

    Project [a#0]
    +- Filter a#0 IN (list#4 [b#1])
       :  +- Project [c#2]
       :     +- Filter (outer(b#1) < d#3)
       :        +- LocalRelation <empty>, [c#2, d#3]
       +- LocalRelation <empty>, [a#0, b#1]

After `PullupCorrelatedPredicates`, it produces query plan like:

    'Project [a#0]
    +- 'Filter a#0 IN (list#4 [(b#1 < d#3)])
       :  +- Project [c#2, d#3]
       :     +- LocalRelation <empty>, [c#2, d#3]
       +- LocalRelation <empty>, [a#0, b#1]

Because the correlated predicate references another attribute `d#3` of the subquery, the predicate has been pulled out of the subquery and `d#3` has been added to the `Project` on top of the subquery.

When `list` in `In` contains just one `ListQuery`, `In.checkInputDataTypes` checks whether the number of `value` expressions matches the number of output columns of the subquery. In the above example, there is only one `value` expression but the subquery output has two attributes `c#2, d#3`, so the check fails and `In.resolved` returns `false`.

We should not let `In.checkInputDataTypes` wrongly report such plans as unresolved, because that makes the structural integrity check fail.

## How was this patch tested?

Added test.

Author: Liang-Chi Hsieh <vi...@gmail.com>

Closes #18968 from viirya/SPARK-21759.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/183d4cb7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/183d4cb7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/183d4cb7

Branch: refs/heads/master
Commit: 183d4cb71fbcbf484fc85d8621e1fe04cbbc8195
Parents: 9e33954
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Thu Aug 24 21:46:58 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Thu Aug 24 21:46:58 2017 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  6 +-
 .../sql/catalyst/analysis/TypeCoercion.scala    |  5 +-
 .../sql/catalyst/expressions/predicates.scala   | 65 +++++++++-----------
 .../sql/catalyst/expressions/subquery.scala     | 13 +++-
 .../spark/sql/catalyst/optimizer/subquery.scala | 10 +--
 .../PullupCorrelatedPredicatesSuite.scala       | 52 ++++++++++++++++
 .../negative-cases/subq-input-typecheck.sql.out |  6 +-
 7 files changed, 106 insertions(+), 51 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/183d4cb7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 70a3885..1e934d0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1286,8 +1286,10 @@ class Analyzer(
           resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
         case e @ Exists(sub, _, exprId) if !sub.resolved =>
           resolveSubQuery(e, plans)(Exists(_, _, exprId))
-        case In(value, Seq(l @ ListQuery(sub, _, exprId))) if value.resolved && !sub.resolved =>
-          val expr = resolveSubQuery(l, plans)(ListQuery(_, _, exprId))
+        case In(value, Seq(l @ ListQuery(sub, _, exprId, _))) if value.resolved && !l.resolved =>
+          val expr = resolveSubQuery(l, plans)((plan, exprs) => {
+            ListQuery(plan, exprs, exprId, plan.output)
+          })
           In(value, Seq(expr))
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/183d4cb7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 06d8350..9ffe646 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -402,7 +402,7 @@ object TypeCoercion {
 
       // Handle type casting required between value expression and subquery output
       // in IN subquery.
-      case i @ In(a, Seq(ListQuery(sub, children, exprId)))
+      case i @ In(a, Seq(ListQuery(sub, children, exprId, _)))
         if !i.resolved && flattenExpr(a).length == sub.output.length =>
         // LHS is the value expression of IN subquery.
         val lhs = flattenExpr(a)
@@ -434,7 +434,8 @@ object TypeCoercion {
             case _ => CreateStruct(castedLhs)
           }
 
-          In(newLhs, Seq(ListQuery(Project(castedRhs, sub), children, exprId)))
+          val newSub = Project(castedRhs, sub)
+          In(newLhs, Seq(ListQuery(newSub, children, exprId, newSub.output)))
         } else {
           i
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/183d4cb7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 7bf10f1..613d620 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -138,32 +138,33 @@ case class Not(child: Expression)
 case class In(value: Expression, list: Seq[Expression]) extends Predicate {
 
   require(list != null, "list should not be null")
+
   override def checkInputDataTypes(): TypeCheckResult = {
-    list match {
-      case ListQuery(sub, _, _) :: Nil =>
-        val valExprs = value match {
-          case cns: CreateNamedStruct => cns.valExprs
-          case expr => Seq(expr)
-        }
-        if (valExprs.length != sub.output.length) {
-          TypeCheckResult.TypeCheckFailure(
-            s"""
-               |The number of columns in the left hand side of an IN subquery does not match the
-               |number of columns in the output of subquery.
-               |#columns in left hand side: ${valExprs.length}.
-               |#columns in right hand side: ${sub.output.length}.
-               |Left side columns:
-               |[${valExprs.map(_.sql).mkString(", ")}].
-               |Right side columns:
-               |[${sub.output.map(_.sql).mkString(", ")}].
-             """.stripMargin)
-        } else {
-          val mismatchedColumns = valExprs.zip(sub.output).flatMap {
-            case (l, r) if l.dataType != r.dataType =>
-              s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})"
-            case _ => None
+    val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType))
+    if (mismatchOpt.isDefined) {
+      list match {
+        case ListQuery(_, _, _, childOutputs) :: Nil =>
+          val valExprs = value match {
+            case cns: CreateNamedStruct => cns.valExprs
+            case expr => Seq(expr)
           }
-          if (mismatchedColumns.nonEmpty) {
+          if (valExprs.length != childOutputs.length) {
+            TypeCheckResult.TypeCheckFailure(
+              s"""
+                 |The number of columns in the left hand side of an IN subquery does not match the
+                 |number of columns in the output of subquery.
+                 |#columns in left hand side: ${valExprs.length}.
+                 |#columns in right hand side: ${childOutputs.length}.
+                 |Left side columns:
+                 |[${valExprs.map(_.sql).mkString(", ")}].
+                 |Right side columns:
+                 |[${childOutputs.map(_.sql).mkString(", ")}].""".stripMargin)
+          } else {
+            val mismatchedColumns = valExprs.zip(childOutputs).flatMap {
+              case (l, r) if l.dataType != r.dataType =>
+                s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})"
+              case _ => None
+            }
             TypeCheckResult.TypeCheckFailure(
               s"""
                  |The data type of one or more elements in the left hand side of an IN subquery
@@ -173,20 +174,14 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
                  |Left side:
                  |[${valExprs.map(_.dataType.catalogString).mkString(", ")}].
                  |Right side:
-                 |[${sub.output.map(_.dataType.catalogString).mkString(", ")}].
-               """.stripMargin)
-          } else {
-            TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
+                 |[${childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin)
           }
-        }
-      case _ =>
-        val mismatchOpt = list.find(l => l.dataType != value.dataType)
-        if (mismatchOpt.isDefined) {
+        case _ =>
           TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
             s"${value.dataType} != ${mismatchOpt.get.dataType}")
-        } else {
-          TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
-        }
+      }
+    } else {
+      TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/183d4cb7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index d7b493d..c614604 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -274,9 +274,15 @@ object ScalarSubquery {
 case class ListQuery(
     plan: LogicalPlan,
     children: Seq[Expression] = Seq.empty,
-    exprId: ExprId = NamedExpression.newExprId)
+    exprId: ExprId = NamedExpression.newExprId,
+    childOutputs: Seq[Attribute] = Seq.empty)
   extends SubqueryExpression(plan, children, exprId) with Unevaluable {
-  override def dataType: DataType = plan.schema.fields.head.dataType
+  override def dataType: DataType = if (childOutputs.length > 1) {
+    childOutputs.toStructType
+  } else {
+    childOutputs.head.dataType
+  }
+  override lazy val resolved: Boolean = childrenResolved && plan.resolved && childOutputs.nonEmpty
   override def nullable: Boolean = false
   override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan)
   override def toString: String = s"list#${exprId.id} $conditionString"
@@ -284,7 +290,8 @@ case class ListQuery(
     ListQuery(
       plan.canonicalized,
       children.map(_.canonicalized),
-      ExprId(0))
+      ExprId(0),
+      childOutputs.map(_.canonicalized.asInstanceOf[Attribute]))
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/183d4cb7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 9dbb6b1..4386a10 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -68,11 +68,11 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
         case (p, Not(Exists(sub, conditions, _))) =>
           val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
           Join(outerPlan, sub, LeftAnti, joinCond)
-        case (p, In(value, Seq(ListQuery(sub, conditions, _)))) =>
+        case (p, In(value, Seq(ListQuery(sub, conditions, _, _)))) =>
           val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
           val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
           Join(outerPlan, sub, LeftSemi, joinCond)
-        case (p, Not(In(value, Seq(ListQuery(sub, conditions, _))))) =>
+        case (p, Not(In(value, Seq(ListQuery(sub, conditions, _, _))))) =>
           // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
           // Construct the condition. A NULL in one of the conditions is regarded as a positive
           // result; such a row will be filtered out by the Anti-Join operator.
@@ -116,7 +116,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
           val exists = AttributeReference("exists", BooleanType, nullable = false)()
           newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
           exists
-        case In(value, Seq(ListQuery(sub, conditions, _))) =>
+        case In(value, Seq(ListQuery(sub, conditions, _, _))) =>
           val exists = AttributeReference("exists", BooleanType, nullable = false)()
           val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
           val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
@@ -227,9 +227,9 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
       case Exists(sub, children, exprId) if children.nonEmpty =>
         val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
         Exists(newPlan, newCond, exprId)
-      case ListQuery(sub, _, exprId) =>
+      case ListQuery(sub, _, exprId, childOutputs) =>
         val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
-        ListQuery(newPlan, newCond, exprId)
+        ListQuery(newPlan, newCond, exprId, childOutputs)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/183d4cb7/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala
new file mode 100644
index 0000000..169b8737
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{In, ListQuery}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+class PullupCorrelatedPredicatesSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("PullupCorrelatedPredicates", Once,
+        PullupCorrelatedPredicates) :: Nil
+  }
+
+  val testRelation = LocalRelation('a.int, 'b.double)
+  val testRelation2 = LocalRelation('c.int, 'd.double)
+
+  test("PullupCorrelatedPredicates should not produce unresolved plan") {
+    val correlatedSubquery =
+      testRelation2
+        .where('b < 'd)
+        .select('c)
+    val outerQuery =
+      testRelation
+        .where(In('a, Seq(ListQuery(correlatedSubquery))))
+        .select('a).analyze
+    assert(outerQuery.resolved)
+
+    val optimized = Optimize.execute(outerQuery)
+    assert(optimized.resolved)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/183d4cb7/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out
index 9ea9d3c..70aeb93 100644
--- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out
@@ -80,8 +80,7 @@ number of columns in the output of subquery.
 Left side columns:
 [t1.`t1a`].
 Right side columns:
-[t2.`t2a`, t2.`t2b`].
-             ;
+[t2.`t2a`, t2.`t2b`].;
 
 
 -- !query 6
@@ -102,5 +101,4 @@ number of columns in the output of subquery.
 Left side columns:
 [t1.`t1a`, t1.`t1b`].
 Right side columns:
-[t2.`t2a`].
-             ;
+[t2.`t2a`].;


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org