Posted to commits@spark.apache.org by da...@apache.org on 2016/04/20 00:17:14 UTC

spark git commit: [SPARK-4226] [SQL] Support IN/EXISTS Subqueries

Repository: spark
Updated Branches:
  refs/heads/master 3c91afec2 -> da8859226


[SPARK-4226] [SQL] Support IN/EXISTS Subqueries

### What changes were proposed in this pull request?
This PR adds support for IN/EXISTS predicate sub-queries to Spark. Predicate sub-queries are used as a filtering condition in a query (this is the only supported use case). A predicate sub-query comes in two forms:

- `[NOT] EXISTS(subquery)`
- `[NOT] IN (subquery)`

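For illustration, two queries that this patch enables (the tables `a` and `b` are hypothetical, not part of the patch):

    sql("SELECT * FROM a WHERE a.id IN (SELECT id FROM b)")
    sql("SELECT * FROM a WHERE NOT EXISTS (SELECT * FROM b WHERE b.id = a.id)")
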
This PR is (loosely) based on the work of davies (https://github.com/apache/spark/pull/10706) and chenghao-intel (https://github.com/apache/spark/pull/9055). They should be credited for the work they did.

### How was this patch tested?
Modified parsing unit tests.
Added tests to `org.apache.spark.sql.SubquerySuite`.

cc rxin, davies & chenghao-intel

Author: Herman van Hovell <hv...@questtec.nl>

Closes #12306 from hvanhovell/SPARK-4226.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/da885922
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/da885922
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/da885922

Branch: refs/heads/master
Commit: da8859226e09aa6ebcf6a1c5c1369dec3c216eac
Parents: 3c91afe
Author: Herman van Hovell <hv...@questtec.nl>
Authored: Tue Apr 19 15:16:02 2016 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Tue Apr 19 15:16:02 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  30 +++--
 .../sql/catalyst/analysis/CheckAnalysis.scala   |  40 ++++++-
 .../sql/catalyst/expressions/subquery.scala     |  84 +++++++++++++-
 .../sql/catalyst/optimizer/Optimizer.scala      | 115 ++++++++++++++++++-
 .../spark/sql/catalyst/parser/AstBuilder.scala  |  16 ++-
 .../catalyst/analysis/AnalysisErrorSuite.scala  |  58 +++++++++-
 .../sql/catalyst/parser/ErrorParserSuite.scala  |   6 +-
 .../catalyst/parser/ExpressionParserSuite.scala |   8 +-
 .../sql/catalyst/parser/PlanParserSuite.scala   |   4 +-
 .../apache/spark/sql/execution/subquery.scala   |   6 +-
 .../scala/org/apache/spark/sql/QueryTest.scala  |  53 +++++++--
 .../org/apache/spark/sql/SubquerySuite.scala    |  98 ++++++++++++++++
 12 files changed, 476 insertions(+), 42 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/da885922/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 0e2fd43..2364769 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -855,25 +855,35 @@ class Analyzer(
   }
 
   /**
-   * This rule resolve subqueries inside expressions.
+   * This rule resolves sub-queries inside expressions.
    *
-   * Note: CTE are handled in CTESubstitution.
+   * Note: CTEs are handled in CTESubstitution.
    */
   object ResolveSubquery extends Rule[LogicalPlan] with PredicateHelper {
 
-    private def hasSubquery(e: Expression): Boolean = {
-      e.find(_.isInstanceOf[SubqueryExpression]).isDefined
-    }
-
-    private def hasSubquery(q: LogicalPlan): Boolean = {
-      q.expressions.exists(hasSubquery)
+    /**
+     * Resolve the correlated predicates in the [[Filter]] clauses (e.g. WHERE or HAVING) of a
+     * sub-query by using the plan the predicates should be correlated to.
+     */
+    private def resolveCorrelatedPredicates(q: LogicalPlan, p: LogicalPlan): LogicalPlan = {
+      q transformUp {
+        case f @ Filter(cond, child) if child.resolved && !f.resolved =>
+          val newCond = resolveExpression(cond, p, throws = false)
+          if (!cond.fastEquals(newCond)) {
+            Filter(newCond, child)
+          } else {
+            f
+          }
+      }
     }
 
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-      case q: LogicalPlan if q.childrenResolved && hasSubquery(q) =>
+      case q: LogicalPlan if q.childrenResolved =>
         q transformExpressions {
           case e: SubqueryExpression if !e.query.resolved =>
-            e.withNewPlan(execute(e.query))
+            // First resolve as much of the sub-query as possible. After that we use the children of
+            // this plan to resolve the remaining correlated predicates.
+            e.withNewPlan(q.children.foldLeft(execute(e.query))(resolveCorrelatedPredicates))
         }
     }
   }

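As a sketch of the case this rule handles (written with the Catalyst test DSL; the tables and columns are illustrative): the inner filter references an outer attribute that the sub-query alone cannot resolve, so the rule first runs the analyzer on the sub-query, then re-resolves the remaining Filter conditions against the outer plan's children.

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._

    // SELECT * FROM a WHERE EXISTS (SELECT 1 FROM b WHERE b.x = a.x)
    val outer = table("a")
    val sub = table("b").where(Symbol("b.x") === Symbol("a.x")).select(1)
    // execute(sub) resolves b and b.x; a.x stays unresolved until
    // resolveCorrelatedPredicates re-resolves the Filter against `outer`.
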
http://git-wip-us.apache.org/repos/asf/spark/blob/da885922/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index d6a8c3e..45e4d53 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -20,14 +20,14 @@ package org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.UsingJoin
+import org.apache.spark.sql.catalyst.plans.{Inner, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._
 
 /**
  * Throws user facing errors when passed invalid queries that fail to analyze.
  */
-trait CheckAnalysis {
+trait CheckAnalysis extends PredicateHelper {
 
   /**
    * Override to provide additional checks for correct analysis.
@@ -110,6 +110,39 @@ trait CheckAnalysis {
               s"filter expression '${f.condition.sql}' " +
                 s"of type ${f.condition.dataType.simpleString} is not a boolean.")
 
+          case f @ Filter(condition, child) =>
+            // Make sure that no correlated reference is below Aggregates, below Outer Joins, or
+            // on the right hand side of Unions.
+            lazy val attributes = child.outputSet
+            def failOnCorrelatedReference(
+                p: LogicalPlan,
+                message: String): Unit = p.transformAllExpressions {
+              case e: NamedExpression if attributes.contains(e) =>
+                failAnalysis(s"Accessing outer query column is not allowed in $message: $e")
+            }
+            def checkForCorrelatedReferences(p: PredicateSubquery): Unit = p.query.foreach {
+              case a @ Aggregate(_, _, source) =>
+                failOnCorrelatedReference(source, "an AGGREGATE")
+              case j @ Join(left, _, RightOuter, _) =>
+                failOnCorrelatedReference(left, "a RIGHT OUTER JOIN")
+              case j @ Join(_, right, jt, _) if jt != Inner =>
+                failOnCorrelatedReference(right, "a LEFT (OUTER) JOIN")
+              case Union(_ :: xs) =>
+                xs.foreach(failOnCorrelatedReference(_, "a UNION"))
+              case s: SetOperation =>
+                failOnCorrelatedReference(s.right, "an INTERSECT/EXCEPT")
+              case _ =>
+            }
+            splitConjunctivePredicates(condition).foreach {
+              case p: PredicateSubquery =>
+                checkForCorrelatedReferences(p)
+              case Not(p: PredicateSubquery) =>
+                checkForCorrelatedReferences(p)
+              case e if PredicateSubquery.hasPredicateSubquery(e) =>
+                failAnalysis(s"Predicate sub-queries cannot be used in nested conditions: $e")
+              case e =>
+            }
+
           case j @ Join(_, _, UsingJoin(_, cols), _) =>
             val from = operator.inputSet.map(_.name).mkString(", ")
             failAnalysis(
@@ -209,6 +242,9 @@ trait CheckAnalysis {
                 | but one table has '${firstError.output.length}' columns and another table has
                 | '${s.children.head.output.length}' columns""".stripMargin)
 
+          case p if p.expressions.exists(PredicateSubquery.hasPredicateSubquery) =>
+            failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p")
+
           case _ => // Fallbacks to the following checks
         }
 

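Two hedged examples of queries these checks reject (the tables l(a) and r(c) are illustrative, mirroring the test data added later in this patch):

    // Rejected: the predicate sub-query sits outside a Filter (in the SELECT list).
    sql("SELECT a IN (SELECT c FROM r) FROM l")

    // Rejected: the sub-query is nested inside a disjunction; only top-level
    // conjuncts (and their direct negations) are supported.
    sql("SELECT * FROM l WHERE a IN (SELECT c FROM r) OR a = 1")
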
http://git-wip-us.apache.org/repos/asf/spark/blob/da885922/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index 968bbdb..cbee0e6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -20,12 +20,12 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types._
 
 /**
  * An interface for subquery that is used in expressions.
  */
-abstract class SubqueryExpression extends LeafExpression {
+abstract class SubqueryExpression extends Expression {
 
   /**
    * The logical plan of the query.
@@ -61,6 +61,8 @@ case class ScalarSubquery(
 
   override def dataType: DataType = query.schema.fields.head.dataType
 
+  override def children: Seq[Expression] = Nil
+
   override def checkInputDataTypes(): TypeCheckResult = {
     if (query.schema.length != 1) {
       TypeCheckResult.TypeCheckFailure("Scalar subquery must return only one column, but got " +
@@ -77,3 +79,81 @@ case class ScalarSubquery(
 
   override def toString: String = s"subquery#${exprId.id}"
 }
+
+/**
+ * A predicate subquery checks the existence of a value in a sub-query. We currently only allow
+ * [[PredicateSubquery]] expressions within a Filter plan (i.e. a WHERE or a HAVING clause). These
+ * will be rewritten into left semi/anti joins by the RewritePredicateSubquery rule in the
+ * optimizer's "Finish Analysis" batch.
+ */
+abstract class PredicateSubquery extends SubqueryExpression with Unevaluable with Predicate {
+  override def nullable: Boolean = false
+  override def plan: LogicalPlan = SubqueryAlias(prettyName, query)
+}
+
+object PredicateSubquery {
+  def hasPredicateSubquery(e: Expression): Boolean = {
+    e.find(_.isInstanceOf[PredicateSubquery]).isDefined
+  }
+}
+
+/**
+ * The [[InSubQuery]] predicate checks the existence of a value in a sub-query. For example (SQL):
+ * {{{
+ *   SELECT  *
+ *   FROM    a
+ *   WHERE   a.id IN (SELECT  id
+ *                    FROM    b)
+ * }}}
+ */
+case class InSubQuery(value: Expression, query: LogicalPlan) extends PredicateSubquery {
+  override def children: Seq[Expression] = value :: Nil
+  override lazy val resolved: Boolean = value.resolved && query.resolved
+  override def withNewPlan(plan: LogicalPlan): InSubQuery = InSubQuery(value, plan)
+
+  /**
+   * The unwrapped value side expressions.
+   */
+  lazy val expressions: Seq[Expression] = value match {
+    case CreateStruct(cols) => cols
+    case col => Seq(col)
+  }
+
+  /**
+   * Check if the number of columns and the data types on both sides match.
+   */
+  override def checkInputDataTypes(): TypeCheckResult = {
+    // Check the number of arguments.
+    if (expressions.length != query.output.length) {
+      return TypeCheckResult.TypeCheckFailure(
+        s"The number of fields in the value (${expressions.length}) does not match " +
+          s"the number of columns in the subquery (${query.output.length})")
+    }
+
+    // Check the argument types; return on the first mismatch.
+    expressions.zip(query.output).zipWithIndex.foreach {
+      case ((e, a), i) if e.dataType != a.dataType =>
+        return TypeCheckResult.TypeCheckFailure(
+          s"The data type of value[$i](${e.dataType}) does not match " +
+            s"subquery column '${a.name}' (${a.dataType}).")
+      case _ =>
+    }
+
+    TypeCheckResult.TypeCheckSuccess
+  }
+}
+
+/**
+ * The [[Exists]] expression checks if a row exists in a subquery given some correlated condition.
+ * For example (SQL):
+ * {{{
+ *   SELECT  *
+ *   FROM    a
+ *   WHERE   EXISTS (SELECT  *
+ *                   FROM    b
+ *                   WHERE   b.id = a.id)
+ * }}}
+ */
+case class Exists(query: LogicalPlan) extends PredicateSubquery {
+  override def children: Seq[Expression] = Nil
+  override def withNewPlan(plan: LogicalPlan): Exists = Exists(plan)
+}

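A small sketch of how a multi-column IN is represented (the attributes and the relation below are illustrative): the value side arrives wrapped in CreateStruct, and InSubQuery.expressions unwraps it so the columns can be checked pairwise against the sub-query's output.

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
    import org.apache.spark.sql.types.IntegerType

    val a = AttributeReference("a", IntegerType)()
    val b = AttributeReference("b", IntegerType)()
    val c = AttributeReference("c", IntegerType)()
    val d = AttributeReference("d", IntegerType)()

    // (a, b) IN (SELECT c, d FROM r)
    val in = InSubQuery(CreateStruct(Seq(a, b)), LocalRelation(c, d))
    assert(in.expressions == Seq(a, b))  // the unwrapped value-side columns
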
http://git-wip-us.apache.org/repos/asf/spark/blob/da885922/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 0a5232b..ecc2d77 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import scala.annotation.tailrec
 import scala.collection.immutable.HashSet
+import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
 import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, DistinctAggregationRewriter, EliminateSubqueryAliases, EmptyFunctionRegistry}
 import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{InSubQuery, _}
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions}
@@ -47,6 +48,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
     // However, because we also use the analyzer to canonicalize queries (for view definition),
     // we do not eliminate subqueries or compute current time in the analyzer.
     Batch("Finish Analysis", Once,
+      RewritePredicateSubquery,
       EliminateSubqueryAliases,
       ComputeCurrentTime,
       GetCurrentDatabase(sessionCatalog),
@@ -1446,3 +1448,114 @@ object EmbedSerializerInFilter extends Rule[LogicalPlan] {
       }
   }
 }
+
+/**
+ * This rule rewrites predicate sub-queries into left semi/anti joins. The following predicates
+ * are supported:
+ * a. EXISTS/NOT EXISTS will be rewritten as a semi/anti join; correlated conditions in the
+ *    sub-query's Filters will be pulled out as the join conditions.
+ * b. IN/NOT IN will be rewritten as a semi/anti join; correlated conditions in the sub-query's
+ *    Filters will be pulled out as join conditions, and `value = selected column` will also be
+ *    used as a join condition.
+ */
+object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
+  /**
+   * Pull out all correlated predicates from a given sub-query. This method removes the correlated
+   * predicates from sub-query [[Filter]]s and adds the references of these predicates to
+   * all intermediate [[Project]] clauses (if they are missing) in order to be able to evaluate the
+   * predicates in the join condition.
+   *
+   * This method returns the rewritten sub-query and the extracted correlated predicates (which
+   * are later combined with AND).
+   */
+  private def pullOutCorrelatedPredicates(
+      subquery: LogicalPlan,
+      query: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
+    val references = query.outputSet
+    val predicateMap = mutable.Map.empty[LogicalPlan, Seq[Expression]]
+    val transformed = subquery transformUp {
+      case f @ Filter(cond, child) =>
+        // Find all correlated predicates.
+        val (correlated, local) = splitConjunctivePredicates(cond).partition { e =>
+          e.references.intersect(references).nonEmpty
+        }
+        // Rewrite the filter without the correlated predicates if any.
+        correlated match {
+          case Nil => f
+          case xs if local.nonEmpty =>
+            val newFilter = Filter(local.reduce(And), child)
+            predicateMap += newFilter -> correlated
+            newFilter
+          case xs =>
+            predicateMap += child -> correlated
+            child
+        }
+      case p @ Project(expressions, child) =>
+        // Find all pulled out predicates defined in the Project's subtree.
+        val localPredicates = p.collect(predicateMap).flatten
+
+        // Determine which correlated predicate references are missing from this project.
+        val localPredicateReferences = localPredicates
+          .map(_.references)
+          .reduceOption(_ ++ _)
+          .getOrElse(AttributeSet.empty)
+        val missingReferences = localPredicateReferences -- p.references -- query.outputSet
+
+        // Create a new project if we need to add missing references.
+        if (missingReferences.nonEmpty) {
+          Project(expressions ++ missingReferences, child)
+        } else {
+          p
+        }
+    }
+    (transformed, predicateMap.values.flatten.toSeq)
+  }
+
+  /**
+   * Prepare an [[InSubQuery]] by rewriting it (in case of correlated predicates) and by
+   * constructing the required join condition. Both the rewritten subquery and the constructed
+   * join condition are returned.
+   */
+  private def pullOutCorrelatedPredicates(
+      in: InSubQuery,
+      query: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
+    val (resolved, joinCondition) = pullOutCorrelatedPredicates(in.query, query)
+    val conditions = joinCondition ++ in.expressions.zip(resolved.output).map(EqualTo.tupled)
+    (resolved, conditions)
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case f @ Filter(condition, child) =>
+      val (withSubquery, withoutSubquery) =
+        splitConjunctivePredicates(condition).partition(PredicateSubquery.hasPredicateSubquery)
+
+      // Construct the pruned filter condition.
+      val newFilter: LogicalPlan = withoutSubquery match {
+        case Nil => child
+        case conditions => Filter(conditions.reduce(And), child)
+      }
+
+      // Filter the plan by applying left semi and left anti joins.
+      withSubquery.foldLeft(newFilter) {
+        case (p, Exists(sub)) =>
+          val (resolved, conditions) = pullOutCorrelatedPredicates(sub, p)
+          Join(p, resolved, LeftSemi, conditions.reduceOption(And))
+        case (p, Not(Exists(sub))) =>
+          val (resolved, conditions) = pullOutCorrelatedPredicates(sub, p)
+          Join(p, resolved, LeftAnti, conditions.reduceOption(And))
+        case (p, in: InSubQuery) =>
+          val (resolved, conditions) = pullOutCorrelatedPredicates(in, p)
+          Join(p, resolved, LeftSemi, conditions.reduceOption(And))
+        case (p, Not(in: InSubQuery)) =>
+          val (resolved, conditions) = pullOutCorrelatedPredicates(in, p)
+          // This is a NULL-aware (left) anti join (NAAJ).
+          // Construct the condition. A NULL in one of the conditions is regarded as a positive
+          // result; such a row will be filtered out by the Anti-Join operator.
+          val anyNull = conditions.map(IsNull).reduceLeft(Or)
+          val condition = conditions.reduceLeft(And)
+
+          // Note that this will almost certainly be planned as a Broadcast Nested Loop join.
+          // Use EXISTS if performance matters to you.
+          Join(p, resolved, LeftAnti, Option(Or(anyNull, condition)))
+      }
+  }
+}

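A sketch of the plans this rule produces, assuming hypothetical tables l(a) and r(c):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions.IsNull
    import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
    import org.apache.spark.sql.catalyst.plans.logical.Join

    // SELECT * FROM l WHERE EXISTS (SELECT * FROM r WHERE l.a = r.c)
    val semi = Join(table("l"), table("r"), LeftSemi, Some('a === 'c))

    // SELECT * FROM l WHERE a NOT IN (SELECT c FROM r) becomes the null-aware
    // anti join: a NULL-valued comparison also satisfies the join condition, so
    // such rows are removed by the anti join, matching SQL's NOT IN semantics.
    val naaj = Join(table("l"), table("r"), LeftAnti,
      Some(IsNull('a === 'c) || ('a === 'c)))
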
http://git-wip-us.apache.org/repos/asf/spark/blob/da885922/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index aa59f3f..1c06762 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -391,9 +391,13 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
 
         // Having
         val withHaving = withProject.optional(having) {
-          // Note that we added a cast to boolean. If the expression itself is already boolean,
-          // the optimizer will get rid of the unnecessary cast.
-          Filter(Cast(expression(having), BooleanType), withProject)
+          // Note that we add a cast to non-predicate expressions. If the expression itself is
+          // already boolean, the optimizer will get rid of the unnecessary cast.
+          val predicate = expression(having) match {
+            case p: Predicate => p
+            case e => Cast(e, BooleanType)
+          }
+          Filter(predicate, withProject)
         }
 
         // Distinct
@@ -866,10 +870,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
   }
 
   /**
-   * Create a filtering correlated sub-query. This is not supported yet.
+   * Create a filtering correlated sub-query (EXISTS).
    */
   override def visitExists(ctx: ExistsContext): Expression = {
-    throw new ParseException("EXISTS clauses are not supported.", ctx)
+    Exists(plan(ctx.query))
   }
 
   /**
@@ -944,7 +948,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
           GreaterThanOrEqual(e, expression(ctx.lower)),
           LessThanOrEqual(e, expression(ctx.upper))))
       case SqlBaseParser.IN if ctx.query != null =>
-        throw new ParseException("IN with a Sub-query is currently not supported.", ctx)
+        invertIfNotDefined(InSubQuery(e, plan(ctx.query)))
       case SqlBaseParser.IN =>
         invertIfNotDefined(In(e, ctx.expression.asScala.map(expression)))
       case SqlBaseParser.LIKE =>

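A sketch of what the parser now produces for the affected constructs (plans abbreviated; the table names are illustrative):

    // "exists (select 1 from b)"    ==> Exists(table("b").select(1))
    // "a in (select b from c)"      ==> InSubQuery('a, table("c").select('b))
    // "a not in (select b from c)"  ==> Not(InSubQuery('a, table("c").select('b)))
    //                                   (the Not is added by invertIfNotDefined)
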
http://git-wip-us.apache.org/repos/asf/spark/blob/da885922/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index ad101d1..a90636d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -24,8 +24,8 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count}
+import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
 import org.apache.spark.sql.types._
 
@@ -444,4 +444,60 @@ class AnalysisErrorSuite extends AnalysisTest {
 
     assertAnalysisError(plan2, "map type expression `a` cannot be used in join conditions" :: Nil)
   }
+
+  test("PredicateSubQuery is used outside of a filter") {
+    val a = AttributeReference("a", IntegerType)()
+    val b = AttributeReference("b", IntegerType)()
+    val plan = Project(
+      Seq(a, Alias(InSubQuery(a, LocalRelation(b)), "c")()),
+      LocalRelation(a))
+    assertAnalysisError(plan, "Predicate sub-queries can only be used in a Filter" :: Nil)
+  }
+
+  test("PredicateSubQuery is used is a nested condition") {
+    val a = AttributeReference("a", IntegerType)()
+    val b = AttributeReference("b", IntegerType)()
+    val c = AttributeReference("c", BooleanType)()
+    val plan1 = Filter(Cast(InSubQuery(a, LocalRelation(b)), BooleanType), LocalRelation(a))
+    assertAnalysisError(plan1, "Predicate sub-queries cannot be used in nested conditions" :: Nil)
+
+    val plan2 = Filter(Or(InSubQuery(a, LocalRelation(b)), c), LocalRelation(a, c))
+    assertAnalysisError(plan2, "Predicate sub-queries cannot be used in nested conditions" :: Nil)
+  }
+
+  test("PredicateSubQuery correlated predicate is nested in an illegal plan") {
+    val a = AttributeReference("a", IntegerType)()
+    val b = AttributeReference("b", IntegerType)()
+    val c = AttributeReference("c", IntegerType)()
+
+    val plan1 = Filter(
+      Exists(
+        Join(
+          LocalRelation(b),
+          Filter(EqualTo(a, c), LocalRelation(c)),
+          LeftOuter,
+          Option(EqualTo(b, c)))),
+      LocalRelation(a))
+    assertAnalysisError(plan1, "Accessing outer query column is not allowed in" :: Nil)
+
+    val plan2 = Filter(
+      Exists(
+        Join(
+          Filter(EqualTo(a, c), LocalRelation(c)),
+          LocalRelation(b),
+          RightOuter,
+          Option(EqualTo(b, c)))),
+      LocalRelation(a))
+    assertAnalysisError(plan2, "Accessing outer query column is not allowed in" :: Nil)
+
+    val plan3 = Filter(
+      Exists(Aggregate(Seq.empty, Seq.empty, Filter(EqualTo(a, c), LocalRelation(c)))),
+      LocalRelation(a))
+    assertAnalysisError(plan3, "Accessing outer query column is not allowed in" :: Nil)
+
+    val plan4 = Filter(
+      Exists(Union(LocalRelation(b), Filter(EqualTo(a, c), LocalRelation(c)))),
+      LocalRelation(a))
+    assertAnalysisError(plan4, "Accessing outer query column is not allowed in" :: Nil)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/da885922/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala
index db96bfb..6da3eae 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala
@@ -60,8 +60,8 @@ class ErrorParserSuite extends SparkFunSuite {
     intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0,
       "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported",
       "^^^")
-    intercept("select * from r where a in (select * from t)", 1, 24,
-      "IN with a Sub-query is currently not supported",
-      "------------------------^^^")
+    intercept("select * from r except all select * from t", 1, 0,
+      "EXCEPT ALL is not supported",
+      "^^^")
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/da885922/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index 6f40ec6..d1dc8d6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -113,7 +113,9 @@ class ExpressionParserSuite extends PlanTest {
   }
 
   test("exists expression") {
-    intercept("exists (select 1 from b where b.x = a.x)", "EXISTS clauses are not supported")
+    assertEqual(
+      "exists (select 1 from b where b.x = a.x)",
+      Exists(table("b").where(Symbol("b.x") === Symbol("a.x")).select(1)))
   }
 
   test("comparison expressions") {
@@ -139,7 +141,9 @@ class ExpressionParserSuite extends PlanTest {
   }
 
   test("in sub-query") {
-    intercept("a in (select b from c)", "IN with a Sub-query is currently not supported")
+    assertEqual(
+      "a in (select b from c)",
+      InSubQuery('a, table("c").select('b)))
   }
 
   test("like expressions") {

http://git-wip-us.apache.org/repos/asf/spark/blob/da885922/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 411e237..a1ca55c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -107,7 +107,7 @@ class PlanParserSuite extends PlanTest {
     assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b))
     assertEqual(
       "select a, b from db.c having x < 1",
-      table("db", "c").select('a, 'b).where(('x < 1).cast(BooleanType)))
+      table("db", "c").select('a, 'b).where('x < 1))
     assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b)))
     assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b))
   }
@@ -405,7 +405,7 @@ class PlanParserSuite extends PlanTest {
       "select g from t group by g having a > (select b from s)",
       table("t")
         .groupBy('g)('g)
-        .where(('a > ScalarSubquery(table("s").select('b))).cast(BooleanType)))
+        .where('a > ScalarSubquery(table("s").select('b))))
   }
 
   test("table reference") {

http://git-wip-us.apache.org/repos/asf/spark/blob/da885922/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index b3e8b37..71b6a97 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -18,8 +18,9 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.catalyst.{expressions, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, Literal, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -42,6 +43,7 @@ case class ScalarSubquery(
   override def plan: SparkPlan = Subquery(simpleString, executedPlan)
 
   override def dataType: DataType = executedPlan.schema.fields.head.dataType
+  override def children: Seq[Expression] = Nil
   override def nullable: Boolean = true
   override def toString: String = s"subquery#${exprId.id}"
 

http://git-wip-us.apache.org/repos/asf/spark/blob/da885922/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 2dca792..cbacb5e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import java.util.{Locale, TimeZone}
+import java.util.{ArrayDeque, Locale, TimeZone}
 
 import scala.collection.JavaConverters._
 import scala.util.control.NonFatal
@@ -35,6 +35,8 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.streaming.MemoryPlan
 import org.apache.spark.sql.types.ObjectType
 
+
 abstract class QueryTest extends PlanTest {
 
   protected def sqlContext: SQLContext
@@ -47,6 +49,7 @@ abstract class QueryTest extends PlanTest {
   /**
    * Runs the plan and makes sure the answer contains all of the keywords, or that
    * none of the keywords are listed in the answer.
+   *
    * @param df the [[DataFrame]] to be executed
    * @param exists true to make sure the keywords are listed in the output, otherwise
    *               to make sure none of the keywords are listed in the output
@@ -119,6 +122,7 @@ abstract class QueryTest extends PlanTest {
 
   /**
    * Runs the plan and makes sure the answer matches the expected result.
+   *
    * @param df the [[DataFrame]] to be executed
    * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
    */
@@ -158,6 +162,7 @@ abstract class QueryTest extends PlanTest {
 
   /**
    * Runs the plan and makes sure the answer is within absTol of the expected result.
+   *
    * @param dataFrame the [[DataFrame]] to be executed
    * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
    * @param absTol the absolute tolerance between actual and expected answers.
@@ -198,7 +203,10 @@ abstract class QueryTest extends PlanTest {
   }
 
   private def checkJsonFormat(df: DataFrame): Unit = {
+    // Get the analyzed plan and rewrite the PredicateSubqueries in order to make sure that
+    // the RDD and data resolution below does not break.
     val logicalPlan = df.queryExecution.analyzed
+
     // bypass some cases that we can't handle currently.
     logicalPlan.transform {
       case _: ObjectConsumer => return
@@ -236,9 +244,27 @@ abstract class QueryTest extends PlanTest {
     // RDDs/data are not serializable to JSON, so we need to collect the LogicalPlans that
     // contain this non-serializable data, and use these original ones to replace the
     // null-placeholders in the logical plans parsed from JSON.
-    var logicalRDDs = logicalPlan.collect { case l: LogicalRDD => l }
-    var localRelations = logicalPlan.collect { case l: LocalRelation => l }
-    var inMemoryRelations = logicalPlan.collect { case i: InMemoryRelation => i }
+    val logicalRDDs = new ArrayDeque[LogicalRDD]()
+    val localRelations = new ArrayDeque[LocalRelation]()
+    val inMemoryRelations = new ArrayDeque[InMemoryRelation]()
+    def collectData: (LogicalPlan => Unit) = {
+      case l: LogicalRDD =>
+        logicalRDDs.offer(l)
+      case l: LocalRelation =>
+        localRelations.offer(l)
+      case i: InMemoryRelation =>
+        inMemoryRelations.offer(i)
+      case p =>
+        p.expressions.foreach {
+          _.foreach {
+            case s: SubqueryExpression =>
+              s.query.foreach(collectData)
+            case _ =>
+          }
+        }
+    }
+    logicalPlan.foreach(collectData)
+
 
     val jsonBackPlan = try {
       TreeNode.fromJSON[LogicalPlan](jsonString, sqlContext.sparkContext)
@@ -253,18 +279,15 @@ abstract class QueryTest extends PlanTest {
            """.stripMargin, e)
     }
 
-    val normalized2 = jsonBackPlan transformDown {
+    def renormalize: PartialFunction[LogicalPlan, LogicalPlan] = {
       case l: LogicalRDD =>
-        val origin = logicalRDDs.head
-        logicalRDDs = logicalRDDs.drop(1)
+        val origin = logicalRDDs.pop()
         LogicalRDD(l.output, origin.rdd)(sqlContext)
       case l: LocalRelation =>
-        val origin = localRelations.head
-        localRelations = localRelations.drop(1)
+        val origin = localRelations.pop()
         l.copy(data = origin.data)
       case l: InMemoryRelation =>
-        val origin = inMemoryRelations.head
-        inMemoryRelations = inMemoryRelations.drop(1)
+        val origin = inMemoryRelations.pop()
         InMemoryRelation(
           l.output,
           l.useCompression,
@@ -275,7 +298,13 @@ abstract class QueryTest extends PlanTest {
           origin.cachedColumnBuffers,
           l._statistics,
           origin._batchStats)
+      case p =>
+        p.transformExpressions {
+          case s: SubqueryExpression =>
+            s.withNewPlan(s.query.transformDown(renormalize))
+        }
     }
+    val normalized2 = jsonBackPlan.transformDown(renormalize)
 
     assert(logicalRDDs.isEmpty)
     assert(localRelations.isEmpty)
@@ -309,6 +338,7 @@ object QueryTest {
    * If there was an exception during the execution, or the contents of the DataFrame do not
    * match the expected result, an error message will be returned. Otherwise, [[None]] will
    * be returned.
+   *
    * @param df the [[DataFrame]] to be executed
    * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
    */
@@ -383,6 +413,7 @@ object QueryTest {
 
   /**
    * Runs the plan and makes sure the answer is within absTol of the expected result.
+   *
    * @param actualAnswer the actual result in a [[Row]].
    * @param expectedAnswer the expected result in a [[Row]].
    * @param absTol the absolute tolerance between actual and expected answers.

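A quick note on the ArrayDeque usage above: offer() appends at the tail and pop() removes from the head, so the placeholders in the JSON-parsed plan are matched to the collected originals in the same top-down order. A minimal standalone sketch:

    import java.util.ArrayDeque

    val q = new ArrayDeque[String]()
    q.offer("first")
    q.offer("second")
    assert(q.pop() == "first")  // offer + pop together behave as a FIFO queue
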
http://git-wip-us.apache.org/repos/asf/spark/blob/da885922/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 21b19fe..5742983 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -22,6 +22,38 @@ import org.apache.spark.sql.test.SharedSQLContext
 class SubquerySuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
+  setupTestData()
+
+  // Helper that fixes the tuple element types to boxed types, so the rows below can hold nulls.
+  val row = identity[(java.lang.Integer, java.lang.Double)](_)
+
+  lazy val l = Seq(
+    row(1, 2.0),
+    row(1, 2.0),
+    row(2, 1.0),
+    row(2, 1.0),
+    row(3, 3.0),
+    row(null, null),
+    row(null, 5.0),
+    row(6, null)).toDF("a", "b")
+
+  lazy val r = Seq(
+    row(2, 3.0),
+    row(2, 3.0),
+    row(3, 2.0),
+    row(4, 1.0),
+    row(null, null),
+    row(null, 5.0),
+    row(6, null)).toDF("c", "d")
+
+  lazy val t = r.filter($"c".isNotNull && $"d".isNotNull)
+
+  protected override def beforeAll(): Unit = {
+    super.beforeAll()
+    l.registerTempTable("l")
+    r.registerTempTable("r")
+    t.registerTempTable("t")
+  }
+
   test("simple uncorrelated scalar subquery") {
     assertResult(Array(Row(1))) {
       sql("select (select 1 as b) as b").collect()
@@ -80,4 +112,70 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
         " where key = (select max(key) from subqueryData) - 1)").collect()
     }
   }
+
+  test("EXISTS predicate subquery") {
+    checkAnswer(
+      sql("select * from l where exists(select * from r where l.a = r.c)"),
+      Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Row(6, null) :: Nil)
+
+    checkAnswer(
+      sql("select * from l where exists(select * from r where l.a = r.c) and l.a <= 2"),
+      Row(2, 1.0) :: Row(2, 1.0) :: Nil)
+  }
+
+  test("NOT EXISTS predicate subquery") {
+    checkAnswer(
+      sql("select * from l where not exists(select * from r where l.a = r.c)"),
+      Row(1, 2.0) :: Row(1, 2.0) :: Row(null, null) :: Row(null, 5.0) :: Nil)
+
+    checkAnswer(
+      sql("select * from l where not exists(select * from r where l.a = r.c and l.b < r.d)"),
+      Row(1, 2.0) :: Row(1, 2.0) :: Row(3, 3.0) ::
+      Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil)
+  }
+
+  test("IN predicate subquery") {
+    checkAnswer(
+      sql("select * from l where l.a in (select c from r)"),
+      Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Row(6, null) :: Nil)
+
+    checkAnswer(
+      sql("select * from l where l.a in (select c from r where l.b < r.d)"),
+      Row(2, 1.0) :: Row(2, 1.0) :: Nil)
+
+    checkAnswer(
+      sql("select * from l where l.a in (select c from r) and l.a > 2 and l.b is not null"),
+      Row(3, 3.0) :: Nil)
+  }
+
+  test("NOT IN predicate subquery") {
+    checkAnswer(
+      sql("select * from l where a not in(select c from r)"),
+      Nil)
+
+    checkAnswer(
+      sql("select * from l where a not in(select c from r where c is not null)"),
+      Row(1, 2.0) :: Row(1, 2.0) :: Nil)
+
+    checkAnswer(
+      sql("select * from l where a not in(select c from t where b < d)"),
+      Row(1, 2.0) :: Row(1, 2.0) :: Row(3, 3.0) :: Nil)
+
+    // Empty sub-query
+    checkAnswer(
+      sql("select * from l where a not in(select c from r where c > 10 and b < d)"),
+      Row(1, 2.0) :: Row(1, 2.0) :: Row(2, 1.0) :: Row(2, 1.0) ::
+      Row(3, 3.0) :: Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil)
+  }
+
+  test("complex IN predicate subquery") {
+    checkAnswer(
+      sql("select * from l where (a, b) not in(select c, d from r)"),
+      Nil)
+
+    checkAnswer(
+      sql("select * from l where (a, b) not in(select c, d from t) and (a + b) is not null"),
+      Row(1, 2.0) :: Row(1, 2.0) :: Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Nil)
+  }
 }

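Why the first NOT IN test above expects no rows: r.c contains a NULL, so for every row of l the predicate is either false or NULL, never true. A worked instance for a = 1, using the test data above:

    // a NOT IN (2, 2, 3, 4, NULL, NULL, 6)
    //   = 1 <> 2 AND 1 <> 2 AND 1 <> 3 AND 1 <> 4 AND 1 <> NULL AND 1 <> NULL AND 1 <> 6
    //   = true AND ... AND NULL
    //   = NULL                                // the Filter keeps only true rows
    // Querying t instead (where c and d are non-null) restores the intuitive result.
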
