You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2016/01/29 20:22:16 UTC

spark git commit: [SPARK-12656] [SQL] Implement Intersect with Left-semi Join

Repository: spark
Updated Branches:
  refs/heads/master c5f745ede -> 5f686cc8b


[SPARK-12656] [SQL] Implement Intersect with Left-semi Join

Our current Intersect physical operator simply delegates to RDD.intersection. We should remove the Intersect physical operator and instead transform a logical Intersect into a Distinct over a left-semi join. This way, we can take advantage of all the benefits of the join implementations (e.g. managed memory, code generation, broadcast joins).

After some research, I found that one of the mainstream RDBMSs does the same: in its query explain output, Intersect is replaced by a left-semi join. A left-semi join can also help outer-join elimination in the Optimizer, as shown in the PR: https://github.com/apache/spark/pull/10566

Author: gatorsmile <ga...@gmail.com>
Author: xiaoli <li...@gmail.com>
Author: Xiao Li <xi...@Xiaos-MacBook-Pro.local>

Closes #10630 from gatorsmile/IntersectBySemiJoin.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5f686cc8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5f686cc8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5f686cc8

Branch: refs/heads/master
Commit: 5f686cc8b74ea9e36f56c31f14df90d134fd9343
Parents: c5f745e
Author: gatorsmile <ga...@gmail.com>
Authored: Fri Jan 29 11:22:12 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Fri Jan 29 11:22:12 2016 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 113 ++++++++++---------
 .../sql/catalyst/analysis/CheckAnalysis.scala   |  14 ++-
 .../sql/catalyst/optimizer/Optimizer.scala      |  45 ++++----
 .../catalyst/plans/logical/basicOperators.scala |  32 ++++--
 .../sql/catalyst/analysis/AnalysisSuite.scala   |   5 +
 .../optimizer/AggregateOptimizeSuite.scala      |  12 --
 .../optimizer/ReplaceOperatorSuite.scala        |  59 ++++++++++
 .../catalyst/optimizer/SetOperationSuite.scala  |  15 +--
 .../spark/sql/execution/SparkStrategies.scala   |   5 +-
 .../spark/sql/execution/basicOperators.scala    |  12 --
 .../org/apache/spark/sql/DataFrameSuite.scala   |  21 ++++
 11 files changed, 211 insertions(+), 122 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5f686cc8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 33d76ee..5fe700e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -344,6 +344,63 @@ class Analyzer(
       }
     }
 
+    /**
+     * Generate a new logical plan for the right child with different expression IDs
+     * for all conflicting attributes.
+     */
+    private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = {
+      val conflictingAttributes = left.outputSet.intersect(right.outputSet)
+      logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " +
+        s"between $left and $right")
+
+      right.collect {
+        // Handle base relations that might appear more than once.
+        case oldVersion: MultiInstanceRelation
+            if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
+          val newVersion = oldVersion.newInstance()
+          (oldVersion, newVersion)
+
+        // Handle projects that create conflicting aliases.
+        case oldVersion @ Project(projectList, _)
+            if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
+          (oldVersion, oldVersion.copy(projectList = newAliases(projectList)))
+
+        case oldVersion @ Aggregate(_, aggregateExpressions, _)
+            if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
+          (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
+
+        case oldVersion: Generate
+            if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
+          val newOutput = oldVersion.generatorOutput.map(_.newInstance())
+          (oldVersion, oldVersion.copy(generatorOutput = newOutput))
+
+        case oldVersion @ Window(_, windowExpressions, _, _, child)
+            if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
+              .nonEmpty =>
+          (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions)))
+      }
+        // Only handle first case, others will be fixed on the next pass.
+        .headOption match {
+        case None =>
+          /*
+           * No result implies that there is a logical plan node that produces new references
+           * that this rule cannot handle. When that is the case, there must be another rule
+           * that resolves these conflicts. Otherwise, the analysis will fail.
+           */
+          right
+        case Some((oldRelation, newRelation)) =>
+          val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output))
+          val newRight = right transformUp {
+            case r if r == oldRelation => newRelation
+          } transformUp {
+            case other => other transformExpressions {
+              case a: Attribute => attributeRewrites.get(a).getOrElse(a)
+            }
+          }
+          newRight
+      }
+    }
+
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
       case p: LogicalPlan if !p.childrenResolved => p
 
@@ -388,57 +445,11 @@ class Analyzer(
             .map(_.asInstanceOf[NamedExpression])
         a.copy(aggregateExpressions = expanded)
 
-      // Special handling for cases when self-join introduce duplicate expression ids.
-      case j @ Join(left, right, _, _) if !j.selfJoinResolved =>
-        val conflictingAttributes = left.outputSet.intersect(right.outputSet)
-        logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j")
-
-        right.collect {
-          // Handle base relations that might appear more than once.
-          case oldVersion: MultiInstanceRelation
-              if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
-            val newVersion = oldVersion.newInstance()
-            (oldVersion, newVersion)
-
-          // Handle projects that create conflicting aliases.
-          case oldVersion @ Project(projectList, _)
-              if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
-            (oldVersion, oldVersion.copy(projectList = newAliases(projectList)))
-
-          case oldVersion @ Aggregate(_, aggregateExpressions, _)
-              if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
-            (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
-
-          case oldVersion: Generate
-              if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
-            val newOutput = oldVersion.generatorOutput.map(_.newInstance())
-            (oldVersion, oldVersion.copy(generatorOutput = newOutput))
-
-          case oldVersion @ Window(_, windowExpressions, _, _, child)
-              if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
-                .nonEmpty =>
-            (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions)))
-        }
-        // Only handle first case, others will be fixed on the next pass.
-        .headOption match {
-          case None =>
-            /*
-             * No result implies that there is a logical plan node that produces new references
-             * that this rule cannot handle. When that is the case, there must be another rule
-             * that resolves these conflicts. Otherwise, the analysis will fail.
-             */
-            j
-          case Some((oldRelation, newRelation)) =>
-            val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output))
-            val newRight = right transformUp {
-              case r if r == oldRelation => newRelation
-            } transformUp {
-              case other => other transformExpressions {
-                case a: Attribute => attributeRewrites.get(a).getOrElse(a)
-              }
-            }
-            j.copy(right = newRight)
-        }
+      // To resolve duplicate expression IDs for Join and Intersect
+      case j @ Join(left, right, _, _) if !j.duplicateResolved =>
+        j.copy(right = dedupRight(left, right))
+      case i @ Intersect(left, right) if !i.duplicateResolved =>
+        i.copy(right = dedupRight(left, right))
 
       // When resolve `SortOrder`s in Sort based on child, don't report errors as
       // we still have chance to resolve it based on grandchild

http://git-wip-us.apache.org/repos/asf/spark/blob/5f686cc8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index f2e78d9..4a2f2b8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -214,9 +214,8 @@ trait CheckAnalysis {
               s"""Only a single table generating function is allowed in a SELECT clause, found:
                  | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin)
 
-          // Special handling for cases when self-join introduce duplicate expression ids.
-          case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty =>
-            val conflictingAttributes = left.outputSet.intersect(right.outputSet)
+          case j: Join if !j.duplicateResolved =>
+            val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet)
             failAnalysis(
               s"""
                  |Failure when resolving conflicting references in Join:
@@ -224,6 +223,15 @@ trait CheckAnalysis {
                  |Conflicting attributes: ${conflictingAttributes.mkString(",")}
                  |""".stripMargin)
 
+          case i: Intersect if !i.duplicateResolved =>
+            val conflictingAttributes = i.left.outputSet.intersect(i.right.outputSet)
+            failAnalysis(
+              s"""
+                 |Failure when resolving conflicting references in Intersect:
+                 |$plan
+                 |Conflicting attributes: ${conflictingAttributes.mkString(",")}
+                 |""".stripMargin)
+
           case o if !o.resolved =>
             failAnalysis(
               s"unresolved operator ${operator.simpleString}")

http://git-wip-us.apache.org/repos/asf/spark/blob/5f686cc8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 6addc20..f156b5d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -52,8 +52,10 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
     //   since the other rules might make two separate Unions operators adjacent.
     Batch("Union", Once,
       CombineUnions) ::
+    Batch("Replace Operators", FixedPoint(100),
+      ReplaceIntersectWithSemiJoin,
+      ReplaceDistinctWithAggregate) ::
     Batch("Aggregate", FixedPoint(100),
-      ReplaceDistinctWithAggregate,
       RemoveLiteralFromGroupExpressions) ::
     Batch("Operator Optimizations", FixedPoint(100),
       // Operator push down
@@ -124,18 +126,13 @@ object EliminateSerialization extends Rule[LogicalPlan] {
 }
 
 /**
- * Pushes certain operations to both sides of a Union, Intersect or Except operator.
+ * Pushes certain operations to both sides of a Union or Except operator.
  * Operations that are safe to pushdown are listed as follows.
  * Union:
  * Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is
  * safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT,
  * we will not be able to pushdown Projections.
  *
- * Intersect:
- * It is not safe to pushdown Projections through it because we need to get the
- * intersect of rows by comparing the entire rows. It is fine to pushdown Filters
- * with deterministic condition.
- *
  * Except:
  * It is not safe to pushdown Projections through it because we need to get the
  * intersect of rows by comparing the entire rows. It is fine to pushdown Filters
@@ -153,7 +150,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
 
   /**
    * Rewrites an expression so that it can be pushed to the right side of a
-   * Union, Intersect or Except operator. This method relies on the fact that the output attributes
+   * Union or Except operator. This method relies on the fact that the output attributes
    * of a union/intersect/except are always equal to the left child's output.
    */
   private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = {
@@ -210,17 +207,6 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
       }
       Filter(nondeterministic, Union(newFirstChild +: newOtherChildren))
 
-    // Push down filter through INTERSECT
-    case Filter(condition, Intersect(left, right)) =>
-      val (deterministic, nondeterministic) = partitionByDeterministic(condition)
-      val rewrites = buildRewrites(left, right)
-      Filter(nondeterministic,
-        Intersect(
-          Filter(deterministic, left),
-          Filter(pushToRight(deterministic, rewrites), right)
-        )
-      )
-
     // Push down filter through EXCEPT
     case Filter(condition, Except(left, right)) =>
       val (deterministic, nondeterministic) = partitionByDeterministic(condition)
@@ -1055,6 +1041,27 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] {
 }
 
 /**
+ * Replaces logical [[Intersect]] operator with a left-semi [[Join]] operator.
+ * {{{
+ *   SELECT a1, a2 FROM Tab1 INTERSECT SELECT b1, b2 FROM Tab2
+ *   ==>  SELECT DISTINCT a1, a2 FROM Tab1 LEFT SEMI JOIN Tab2 ON a1<=>b1 AND a2<=>b2
+ * }}}
+ *
+ * Note:
+ * 1. This rule is only applicable to INTERSECT DISTINCT. Do not use it for INTERSECT ALL.
+ * 2. This rule has to be done after de-duplicating the attributes; otherwise, the generated
+ *    join conditions will be incorrect.
+ */
+object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case Intersect(left, right) =>
+      assert(left.output.size == right.output.size)
+      val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) }
+      Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And)))
+  }
+}
+
+/**
  * Removes literals from group expressions in [[Aggregate]], as they have no effect to the result
  * but only makes the grouping key bigger.
  */

http://git-wip-us.apache.org/repos/asf/spark/blob/5f686cc8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index e9c970c..16f4b35 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical
 
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
@@ -90,12 +91,7 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
 }
 
-abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
-  final override lazy val resolved: Boolean =
-    childrenResolved &&
-      left.output.length == right.output.length &&
-      left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
-}
+abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode
 
 private[sql] object SetOperation {
   def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right))
@@ -103,15 +99,30 @@ private[sql] object SetOperation {
 
 case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
 
+  def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
+
   override def output: Seq[Attribute] =
     left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
       leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
     }
+
+  // Intersect are only resolved if they don't introduce ambiguous expression ids,
+  // since the Optimizer will convert Intersect to Join.
+  override lazy val resolved: Boolean =
+    childrenResolved &&
+      left.output.length == right.output.length &&
+      left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } &&
+      duplicateResolved
 }
 
 case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
   /** We don't use right.output because those rows get excluded from the set. */
   override def output: Seq[Attribute] = left.output
+
+  override lazy val resolved: Boolean =
+    childrenResolved &&
+      left.output.length == right.output.length &&
+      left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
 }
 
 /** Factory for constructing new `Union` nodes. */
@@ -169,13 +180,13 @@ case class Join(
     }
   }
 
-  def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
+  def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
 
   // Joins are only resolved if they don't introduce ambiguous expression ids.
   override lazy val resolved: Boolean = {
     childrenResolved &&
       expressions.forall(_.resolved) &&
-      selfJoinResolved &&
+      duplicateResolved &&
       condition.forall(_.dataType == BooleanType)
   }
 }
@@ -249,7 +260,7 @@ case class Range(
     end: Long,
     step: Long,
     numSlices: Int,
-    output: Seq[Attribute]) extends LeafNode {
+    output: Seq[Attribute]) extends LeafNode with MultiInstanceRelation {
   require(step != 0, "step cannot be 0")
   val numElements: BigInt = {
     val safeStart = BigInt(start)
@@ -262,6 +273,9 @@ case class Range(
     }
   }
 
+  override def newInstance(): Range =
+    Range(start, end, step, numSlices, output.map(_.newInstance()))
+
   override def statistics: Statistics = {
     val sizeInBytes = LongType.defaultSize * numElements
     Statistics( sizeInBytes = sizeInBytes )

http://git-wip-us.apache.org/repos/asf/spark/blob/5f686cc8/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index ab68028..1938bce 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -154,6 +154,11 @@ class AnalysisSuite extends AnalysisTest {
     checkAnalysis(plan, expected)
   }
 
+  test("self intersect should resolve duplicate expression IDs") {
+    val plan = testRelation.intersect(testRelation)
+    assertAnalysisSuccess(plan)
+  }
+
   test("SPARK-8654: invalid CAST in NULL IN(...) expression") {
     val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(2))), "a")() :: Nil,
       LocalRelation()

http://git-wip-us.apache.org/repos/asf/spark/blob/5f686cc8/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala
index 37148a2..a4a12c0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala
@@ -28,21 +28,9 @@ class AggregateOptimizeSuite extends PlanTest {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches = Batch("Aggregate", FixedPoint(100),
-      ReplaceDistinctWithAggregate,
       RemoveLiteralFromGroupExpressions) :: Nil
   }
 
-  test("replace distinct with aggregate") {
-    val input = LocalRelation('a.int, 'b.int)
-
-    val query = Distinct(input)
-    val optimized = Optimize.execute(query.analyze)
-
-    val correctAnswer = Aggregate(input.output, input.output, input)
-
-    comparePlans(optimized, correctAnswer)
-  }
-
   test("remove literals in grouping expression") {
     val input = LocalRelation('a.int, 'b.int)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/5f686cc8/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
new file mode 100644
index 0000000..f8ae5d9
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+class ReplaceOperatorSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Replace Operators", FixedPoint(100),
+        ReplaceDistinctWithAggregate,
+        ReplaceIntersectWithSemiJoin) :: Nil
+  }
+
+  test("replace Intersect with Left-semi Join") {
+    val table1 = LocalRelation('a.int, 'b.int)
+    val table2 = LocalRelation('c.int, 'd.int)
+
+    val query = Intersect(table1, table2)
+    val optimized = Optimize.execute(query.analyze)
+
+    val correctAnswer =
+      Aggregate(table1.output, table1.output,
+        Join(table1, table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd))).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("replace Distinct with Aggregate") {
+    val input = LocalRelation('a.int, 'b.int)
+
+    val query = Distinct(input)
+    val optimized = Optimize.execute(query.analyze)
+
+    val correctAnswer = Aggregate(input.output, input.output, input)
+
+    comparePlans(optimized, correctAnswer)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/5f686cc8/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
index 2283f7c..b8ea32b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
@@ -39,7 +39,6 @@ class SetOperationSuite extends PlanTest {
   val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
   val testRelation3 = LocalRelation('g.int, 'h.int, 'i.int)
   val testUnion = Union(testRelation :: testRelation2 :: testRelation3 :: Nil)
-  val testIntersect = Intersect(testRelation, testRelation2)
   val testExcept = Except(testRelation, testRelation2)
 
   test("union: combine unions into one unions") {
@@ -57,19 +56,12 @@ class SetOperationSuite extends PlanTest {
     comparePlans(combinedUnionsOptimized, unionOptimized3)
   }
 
-  test("intersect/except: filter to each side") {
-    val intersectQuery = testIntersect.where('b < 10)
+  test("except: filter to each side") {
     val exceptQuery = testExcept.where('c >= 5)
-
-    val intersectOptimized = Optimize.execute(intersectQuery.analyze)
     val exceptOptimized = Optimize.execute(exceptQuery.analyze)
-
-    val intersectCorrectAnswer =
-      Intersect(testRelation.where('b < 10), testRelation2.where('e < 10)).analyze
     val exceptCorrectAnswer =
       Except(testRelation.where('c >= 5), testRelation2.where('f >= 5)).analyze
 
-    comparePlans(intersectOptimized, intersectCorrectAnswer)
     comparePlans(exceptOptimized, exceptCorrectAnswer)
   }
 
@@ -95,13 +87,8 @@ class SetOperationSuite extends PlanTest {
   }
 
   test("SPARK-10539: Project should not be pushed down through Intersect or Except") {
-    val intersectQuery = testIntersect.select('b, 'c)
     val exceptQuery = testExcept.select('a, 'b, 'c)
-
-    val intersectOptimized = Optimize.execute(intersectQuery.analyze)
     val exceptOptimized = Optimize.execute(exceptQuery.analyze)
-
-    comparePlans(intersectOptimized, intersectQuery.analyze)
     comparePlans(exceptOptimized, exceptQuery.analyze)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/5f686cc8/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 60fbb59..9293e55 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -298,6 +298,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.Distinct(child) =>
         throw new IllegalStateException(
           "logical distinct operator should have been replaced by aggregate in the optimizer")
+      case logical.Intersect(left, right) =>
+        throw new IllegalStateException(
+          "logical intersect operator should have been replaced by semi-join in the optimizer")
 
       case logical.MapPartitions(f, in, out, child) =>
         execution.MapPartitions(f, in, out, planLater(child)) :: Nil
@@ -340,8 +343,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.Union(unionChildren.map(planLater)) :: Nil
       case logical.Except(left, right) =>
         execution.Except(planLater(left), planLater(right)) :: Nil
-      case logical.Intersect(left, right) =>
-        execution.Intersect(planLater(left), planLater(right)) :: Nil
       case g @ logical.Generate(generator, join, outer, _, _, child) =>
         execution.Generate(
           generator, join = join, outer = outer, g.output, planLater(child)) :: Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/5f686cc8/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index e7a73d5..fd81531 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -421,18 +421,6 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode {
 }
 
 /**
- * Returns the rows in left that also appear in right using the built in spark
- * intersection function.
- */
-case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode {
-  override def output: Seq[Attribute] = children.head.output
-
-  protected override def doExecute(): RDD[InternalRow] = {
-    left.execute().map(_.copy()).intersection(right.execute().map(_.copy()))
-  }
-}
-
-/**
  * A plan node that does nothing but lie about the output of its child.  Used to spice a
  * (hopefully structurally equivalent) tree from a different optimization sequence into an already
  * resolved tree.

http://git-wip-us.apache.org/repos/asf/spark/blob/5f686cc8/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 09bbe57..4ff99bd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -349,6 +349,27 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Row(3, "c") ::
       Row(4, "d") :: Nil)
     checkAnswer(lowerCaseData.intersect(upperCaseData), Nil)
+
+    // check null equality
+    checkAnswer(
+      nullInts.intersect(nullInts),
+      Row(1) ::
+      Row(2) ::
+      Row(3) ::
+      Row(null) :: Nil)
+
+    // check if values are de-duplicated
+    checkAnswer(
+      allNulls.intersect(allNulls),
+      Row(null) :: Nil)
+
+    // check if values are de-duplicated
+    val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value")
+    checkAnswer(
+      df.intersect(df),
+      Row("id1", 1) ::
+      Row("id", 1) ::
+      Row("id1", 2) :: Nil)
   }
 
   test("intersect - nullability") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org