Posted to commits@spark.apache.org by li...@apache.org on 2018/07/30 05:11:06 UTC
spark git commit: [SPARK-21274][SQL] Implement INTERSECT ALL clause
Repository: spark
Updated Branches:
refs/heads/master 6690924c4 -> 65a4bc143
[SPARK-21274][SQL] Implement INTERSECT ALL clause
## What changes were proposed in this pull request?
Implements the INTERSECT ALL clause through a query rewrite that uses existing operators in Spark. Please refer to [Link](https://drive.google.com/open?id=1nyW0T0b_ajUduQoPgZLAsyHK8s3_dko3ulQuxaLpUXE) for the design.
Input Query
```SQL
SELECT c1 FROM ut1 INTERSECT ALL SELECT c1 FROM ut2
```
Rewritten Query
```SQL
SELECT c1
FROM (
SELECT replicate_row(min_count, c1)
FROM (
SELECT c1,
IF (vcol1_cnt > vcol2_cnt, vcol2_cnt, vcol1_cnt) AS min_count
FROM (
SELECT c1, count(vcol1) as vcol1_cnt, count(vcol2) as vcol2_cnt
FROM (
SELECT c1, true as vcol1, null as vcol2 FROM ut1
UNION ALL
SELECT c1, null as vcol1, true as vcol2 FROM ut2
) AS union_all
GROUP BY c1
HAVING vcol1_cnt >= 1 AND vcol2_cnt >= 1
)
)
)
```
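To sanity-check the semantics of the rewrite: each common row is emitted min(left count, right count) times. A minimal spark-shell sketch (Spark 2.4+; the ut1/ut2 contents here are made up for illustration):

```scala
// Assumes a spark-shell session, where `spark` is in scope.
import spark.implicits._

Seq("x", "x", "x", "y").toDF("c1").createOrReplaceTempView("ut1")
Seq("x", "x", "z").toDF("c1").createOrReplaceTempView("ut2")

// "x" occurs 3 times on the left and 2 times on the right, so
// INTERSECT ALL emits it min(3, 2) = 2 times; "y" and "z" are dropped.
spark.sql("SELECT c1 FROM ut1 INTERSECT ALL SELECT c1 FROM ut2").show()
// +---+
// | c1|
// +---+
// |  x|
// |  x|
// +---+
```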
## How was this patch tested?
Added test cases in SQLQueryTestSuite, DataFrameSuite, and SetOperationSuite.
Author: Dilip Biswal <db...@us.ibm.com>
Closes #21886 from dilipbiswal/dkb_intersect_all_final.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/65a4bc14
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/65a4bc14
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/65a4bc14
Branch: refs/heads/master
Commit: 65a4bc143ab5dc2ced589dc107bbafa8a7290931
Parents: 6690924
Author: Dilip Biswal <db...@us.ibm.com>
Authored: Sun Jul 29 22:11:01 2018 -0700
Committer: Xiao Li <ga...@gmail.com>
Committed: Sun Jul 29 22:11:01 2018 -0700
----------------------------------------------------------------------
python/pyspark/sql/dataframe.py | 22 ++
.../spark/sql/catalyst/analysis/Analyzer.scala | 2 +-
.../sql/catalyst/analysis/TypeCoercion.scala | 4 +-
.../analysis/UnsupportedOperationChecker.scala | 2 +-
.../sql/catalyst/optimizer/Optimizer.scala | 81 ++++++-
.../spark/sql/catalyst/parser/AstBuilder.scala | 2 +-
.../plans/logical/basicLogicalOperators.scala | 7 +-
.../catalyst/optimizer/SetOperationSuite.scala | 32 ++-
.../sql/catalyst/parser/PlanParserSuite.scala | 1 -
.../scala/org/apache/spark/sql/Dataset.scala | 19 +-
.../spark/sql/execution/SparkStrategies.scala | 8 +-
.../sql-tests/inputs/intersect-all.sql | 123 ++++++++++
.../sql-tests/results/intersect-all.sql.out | 241 +++++++++++++++++++
.../org/apache/spark/sql/DataFrameSuite.scala | 54 +++++
.../org/apache/spark/sql/test/SQLTestData.scala | 13 +
15 files changed, 599 insertions(+), 12 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index b2e0a5b..07fb260 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1500,6 +1500,28 @@ class DataFrame(object):
"""
return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)
+ @since(2.4)
+ def intersectAll(self, other):
+ """ Return a new :class:`DataFrame` containing rows in both this dataframe and other
+ dataframe while preserving duplicates.
+
+ This is equivalent to `INTERSECT ALL` in SQL.
+ >>> df1 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3), ("c", 4)], ["C1", "C2"])
+ >>> df2 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3)], ["C1", "C2"])
+
+ >>> df1.intersectAll(df2).sort("C1", "C2").show()
+ +---+---+
+ | C1| C2|
+ +---+---+
+ | a| 1|
+ | a| 1|
+ | b| 3|
+ +---+---+
+
+ Also as standard in SQL, this function resolves columns by position (not by name).
+ """
+ return DataFrame(self._jdf.intersectAll(other._jdf), self.sql_ctx)
+
@since(1.3)
def subtract(self, other):
""" Return a new :class:`DataFrame` containing rows in this frame
http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8abb1c7..9965cd6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -914,7 +914,7 @@ class Analyzer(
// To resolve duplicate expression IDs for Join and Intersect
case j @ Join(left, right, _, _) if !j.duplicateResolved =>
j.copy(right = dedupRight(left, right))
- case i @ Intersect(left, right) if !i.duplicateResolved =>
+ case i @ Intersect(left, right, _) if !i.duplicateResolved =>
i.copy(right = dedupRight(left, right))
case e @ Except(left, right, _) if !e.duplicateResolved =>
e.copy(right = dedupRight(left, right))
http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index f9edca5..7dd26b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -325,11 +325,11 @@ object TypeCoercion {
assert(newChildren.length == 2)
Except(newChildren.head, newChildren.last, isAll)
- case s @ Intersect(left, right) if s.childrenResolved &&
+ case s @ Intersect(left, right, isAll) if s.childrenResolved &&
left.output.length == right.output.length && !s.resolved =>
val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
assert(newChildren.length == 2)
- Intersect(newChildren.head, newChildren.last)
+ Intersect(newChildren.head, newChildren.last, isAll)
case s: Union if s.childrenResolved &&
s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index c9a3ee4..cff4cee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -309,7 +309,7 @@ object UnsupportedOperationChecker {
case Except(left, right, _) if right.isStreaming =>
throwError("Except on a streaming DataFrame/Dataset on the right is not supported")
- case Intersect(left, right) if left.isStreaming && right.isStreaming =>
+ case Intersect(left, right, _) if left.isStreaming && right.isStreaming =>
throwError("Intersect between two streaming DataFrames/Datasets is not supported")
case GroupingSets(_, _, child, _) if child.isStreaming =>
http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 193f659..105623c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -136,6 +136,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
OptimizeSubqueries) ::
Batch("Replace Operators", fixedPoint,
RewriteExcepAll,
+ RewriteIntersectAll,
ReplaceIntersectWithSemiJoin,
ReplaceExceptWithFilter,
ReplaceExceptWithAntiJoin,
@@ -1402,7 +1403,7 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
*/
object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case Intersect(left, right) =>
+ case Intersect(left, right, false) =>
assert(left.output.size == right.output.size)
val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) }
Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And)))
@@ -1489,6 +1490,84 @@ object RewriteExcepAll extends Rule[LogicalPlan] {
}
/**
+ * Replaces the logical [[Intersect]] operator with a combination of Union, Aggregate
+ * and Generate operators.
+ *
+ * Input Query :
+ * {{{
+ * SELECT c1 FROM ut1 INTERSECT ALL SELECT c1 FROM ut2
+ * }}}
+ *
+ * Rewritten Query:
+ * {{{
+ * SELECT c1
+ * FROM (
+ * SELECT replicate_row(min_count, c1)
+ * FROM (
+ * SELECT c1, If (vcol1_cnt > vcol2_cnt, vcol2_cnt, vcol1_cnt) AS min_count
+ * FROM (
+ * SELECT c1, count(vcol1) as vcol1_cnt, count(vcol2) as vcol2_cnt
+ * FROM (
+ * SELECT true as vcol1, null as vcol2, c1 FROM ut1
+ * UNION ALL
+ * SELECT null as vcol1, true as vcol2, c1 FROM ut2
+ * ) AS union_all
+ * GROUP BY c1
+ * HAVING vcol1_cnt >= 1 AND vcol2_cnt >= 1
+ * )
+ * )
+ * )
+ * }}}
+ */
+object RewriteIntersectAll extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case Intersect(left, right, true) =>
+ assert(left.output.size == right.output.size)
+
+ val trueVcol1 = Alias(Literal(true), "vcol1")()
+ val nullVcol1 = Alias(Literal(null, BooleanType), "vcol1")()
+
+ val trueVcol2 = Alias(Literal(true), "vcol2")()
+ val nullVcol2 = Alias(Literal(null, BooleanType), "vcol2")()
+
+ // Add a projection on the top of left and right plans to project out
+ // the additional virtual columns.
+ val leftPlanWithAddedVirtualCols = Project(Seq(trueVcol1, nullVcol2) ++ left.output, left)
+ val rightPlanWithAddedVirtualCols = Project(Seq(nullVcol1, trueVcol2) ++ right.output, right)
+
+ val unionPlan = Union(leftPlanWithAddedVirtualCols, rightPlanWithAddedVirtualCols)
+
+ // Expressions to compute count and minimum of both the counts.
+ val vCol1AggrExpr =
+ Alias(AggregateExpression(Count(unionPlan.output(0)), Complete, false), "vcol1_count")()
+ val vCol2AggrExpr =
+ Alias(AggregateExpression(Count(unionPlan.output(1)), Complete, false), "vcol2_count")()
+ val ifExpression = Alias(If(
+ GreaterThan(vCol1AggrExpr.toAttribute, vCol2AggrExpr.toAttribute),
+ vCol2AggrExpr.toAttribute,
+ vCol1AggrExpr.toAttribute
+ ), "min_count")()
+
+ val aggregatePlan = Aggregate(left.output,
+ Seq(vCol1AggrExpr, vCol2AggrExpr) ++ left.output, unionPlan)
+ val filterPlan = Filter(And(GreaterThanOrEqual(vCol1AggrExpr.toAttribute, Literal(1L)),
+ GreaterThanOrEqual(vCol2AggrExpr.toAttribute, Literal(1L))), aggregatePlan)
+ val projectMinPlan = Project(left.output ++ Seq(ifExpression), filterPlan)
+
+ // Apply the replicator to replicate rows based on min_count
+ val genRowPlan = Generate(
+ ReplicateRows(Seq(ifExpression.toAttribute) ++ left.output),
+ unrequiredChildIndex = Nil,
+ outer = false,
+ qualifier = None,
+ left.output,
+ projectMinPlan
+ )
+ Project(left.output, genRowPlan)
+ }
+}
+
+/**
* Removes literals from group expressions in [[Aggregate]], as they have no effect to the result
* but only makes the grouping key bigger.
*/
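Because RewriteIntersectAll runs as an optimizer rule, the rewrite shows up in the optimized plan rather than the analyzed plan. A small spark-shell sketch to observe it (the exact operator rendering may differ by version):

```scala
// The optimized plan should show Union, Aggregate, Filter and a Generate
// node (applying ReplicateRows) where the Intersect operator used to be.
val df = spark.range(3).intersectAll(spark.range(5))
df.explain(true)
```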
http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 8b3c068..8a8db6d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -533,7 +533,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
case SqlBaseParser.UNION =>
Distinct(Union(left, right))
case SqlBaseParser.INTERSECT if all =>
- throw new ParseException("INTERSECT ALL is not supported.", ctx)
+ Intersect(left, right, isAll = true)
case SqlBaseParser.INTERSECT =>
Intersect(left, right)
case SqlBaseParser.EXCEPT if all =>
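With the ParseException removed, `INTERSECT ALL` now parses into `Intersect(left, right, isAll = true)` and is rewritten later by RewriteIntersectAll. A minimal check, using a hypothetical one-row query in the same spark-shell session:

```scala
// Before this patch this line threw "INTERSECT ALL is not supported.";
// now it parses and returns the single common row.
spark.sql("SELECT 1 INTERSECT ALL SELECT 1").collect()
// res: Array[org.apache.spark.sql.Row] = Array([1])
```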
http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 498a13a..13b5130 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -164,7 +164,12 @@ object SetOperation {
def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right))
}
-case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
+case class Intersect(
+ left: LogicalPlan,
+ right: LogicalPlan,
+ isAll: Boolean = false) extends SetOperation(left, right) {
+
+ override def nodeName: String = getClass.getSimpleName + (if (isAll) "All" else "")
override def output: Seq[Attribute] =
left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
index f002aa3..cb744be 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, GreaterThan, Literal, ReplicateRows}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, GreaterThan, GreaterThanOrEqual, If, Literal, ReplicateRows}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.types.BooleanType
class SetOperationSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
@@ -166,4 +167,33 @@ class SetOperationSuite extends PlanTest {
))
comparePlans(expectedPlan, rewrittenPlan)
}
+
+ test("INTERSECT ALL rewrite") {
+ val input = Intersect(testRelation, testRelation2, isAll = true)
+ val rewrittenPlan = RewriteIntersectAll(input)
+ val leftRelation = testRelation
+ .select(Literal(true).as("vcol1"), Literal(null, BooleanType).as("vcol2"), 'a, 'b, 'c)
+ val rightRelation = testRelation2
+ .select(Literal(null, BooleanType).as("vcol1"), Literal(true).as("vcol2"), 'd, 'e, 'f)
+ val planFragment = leftRelation.union(rightRelation)
+ .groupBy('a, 'b, 'c)(count('vcol1).as("vcol1_count"),
+ count('vcol2).as("vcol2_count"), 'a, 'b, 'c)
+ .where(And(GreaterThanOrEqual('vcol1_count, Literal(1L)),
+ GreaterThanOrEqual('vcol2_count, Literal(1L))))
+ .select('a, 'b, 'c,
+ If(GreaterThan('vcol1_count, 'vcol2_count), 'vcol2_count, 'vcol1_count).as("min_count"))
+ .analyze
+ val multiplierAttr = planFragment.output.last
+ val output = planFragment.output.dropRight(1)
+ val expectedPlan = Project(output,
+ Generate(
+ ReplicateRows(Seq(multiplierAttr) ++ output),
+ Nil,
+ false,
+ None,
+ output,
+ planFragment
+ ))
+ comparePlans(expectedPlan, rewrittenPlan)
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 629e3c4..9be0ec5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -70,7 +70,6 @@ class PlanParserSuite extends AnalysisTest {
intercept("select * from a minus all select * from b", "MINUS ALL is not supported.")
assertEqual("select * from a minus distinct select * from b", a.except(b))
assertEqual("select * from a intersect select * from b", a.intersect(b))
- intercept("select * from a intersect all select * from b", "INTERSECT ALL is not supported.")
assertEqual("select * from a intersect distinct select * from b", a.intersect(b))
}
http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index e6a3b0a..d36c8d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1935,6 +1935,23 @@ class Dataset[T] private[sql](
}
/**
+ * Returns a new Dataset containing rows that occur in both this Dataset and another
+ * Dataset, while preserving duplicates.
+ * This is equivalent to `INTERSECT ALL` in SQL.
+ *
+ * @note Equality checking is performed directly on the encoded representation of the data
+ * and thus is not affected by a custom `equals` function defined on `T`. Also as standard
+ * in SQL, this function resolves columns by position (not by name).
+ *
+ * @group typedrel
+ * @since 2.4.0
+ */
+ def intersectAll(other: Dataset[T]): Dataset[T] = withSetOperator {
+ Intersect(logicalPlan, other.logicalPlan, isAll = true)
+ }
+
+ /**
* Returns a new Dataset containing rows in this Dataset but not in another Dataset.
* This is equivalent to `EXCEPT DISTINCT` in SQL.
*
@@ -1961,7 +1978,7 @@ class Dataset[T] private[sql](
* @since 2.4.0
*/
def exceptAll(other: Dataset[T]): Dataset[T] = withSetOperator {
- Except(planWithBarrier, other.planWithBarrier, isAll = true)
+ Except(logicalPlan, other.logicalPlan, isAll = true)
}
/**
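For comparison with `intersect`, which deduplicates its result, a sketch of the new typed API (values made up for illustration; same spark-shell session, with `spark.implicits._` imported as above):

```scala
val left = Seq(1, 1, 2, 3).toDS()
val right = Seq(1, 1, 1, 2).toDS()

// Each value survives up to its minimum multiplicity on either side:
// 1 appears min(2, 3) = 2 times, 2 appears min(1, 1) = 1 time.
left.intersectAll(right).orderBy("value").show()
// +-----+
// |value|
// +-----+
// |    1|
// |    1|
// |    2|
// +-----+
```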
http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 3f5fd3d..75eff8a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -529,9 +529,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Distinct(child) =>
throw new IllegalStateException(
"logical distinct operator should have been replaced by aggregate in the optimizer")
- case logical.Intersect(left, right) =>
+ case logical.Intersect(left, right, false) =>
throw new IllegalStateException(
- "logical intersect operator should have been replaced by semi-join in the optimizer")
+ "logical intersect operator should have been replaced by semi-join in the optimizer")
+ case logical.Intersect(left, right, true) =>
+ throw new IllegalStateException(
+ "logical intersect operator should have been replaced by union, aggregate" +
+ "and generate operators in the optimizer")
case logical.Except(left, right, false) =>
throw new IllegalStateException(
"logical except operator should have been replaced by anti-join in the optimizer")
http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql b/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql
new file mode 100644
index 0000000..ff4395c
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql
@@ -0,0 +1,123 @@
+CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES
+ (1, 2),
+ (1, 2),
+ (1, 3),
+ (1, 3),
+ (2, 3),
+ (null, null),
+ (null, null)
+ AS tab1(k, v);
+CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES
+ (1, 2),
+ (1, 2),
+ (2, 3),
+ (3, 4),
+ (null, null),
+ (null, null)
+ AS tab2(k, v);
+
+-- Basic INTERSECT ALL
+SELECT * FROM tab1
+INTERSECT ALL
+SELECT * FROM tab2;
+
+-- INTERSECT ALL same table in both branches
+SELECT * FROM tab1
+INTERSECT ALL
+SELECT * FROM tab1 WHERE k = 1;
+
+-- Empty left relation
+SELECT * FROM tab1 WHERE k > 2
+INTERSECT ALL
+SELECT * FROM tab2;
+
+-- Empty right relation
+SELECT * FROM tab1
+INTERSECT ALL
+SELECT * FROM tab2 WHERE k > 3;
+
+-- Type Coerced INTERSECT ALL
+SELECT * FROM tab1
+INTERSECT ALL
+SELECT CAST(1 AS BIGINT), CAST(2 AS BIGINT);
+
+-- Error as types of two side are not compatible
+SELECT * FROM tab1
+INTERSECT ALL
+SELECT array(1), 2;
+
+-- Mismatch on number of columns across both branches
+SELECT k FROM tab1
+INTERSECT ALL
+SELECT k, v FROM tab2;
+
+-- Basic
+SELECT * FROM tab2
+INTERSECT ALL
+SELECT * FROM tab1
+INTERSECT ALL
+SELECT * FROM tab2;
+
+-- Chain of different set operations
+-- We need to parenthesize the following two queries to enforce
+-- a certain evaluation order of the operators. After the fix for
+-- SPARK-24966 the parentheses can be removed.
+SELECT * FROM tab1
+EXCEPT
+SELECT * FROM tab2
+UNION ALL
+(
+SELECT * FROM tab1
+INTERSECT ALL
+SELECT * FROM tab2
+);
+
+-- Chain of different set operations
+SELECT * FROM tab1
+EXCEPT
+SELECT * FROM tab2
+EXCEPT
+(
+SELECT * FROM tab1
+INTERSECT ALL
+SELECT * FROM tab2
+);
+
+-- Join under intersect all
+SELECT *
+FROM (SELECT tab1.k,
+ tab2.v
+ FROM tab1
+ JOIN tab2
+ ON tab1.k = tab2.k)
+INTERSECT ALL
+SELECT *
+FROM (SELECT tab1.k,
+ tab2.v
+ FROM tab1
+ JOIN tab2
+ ON tab1.k = tab2.k);
+
+-- Join under intersect all (2)
+SELECT *
+FROM (SELECT tab1.k,
+ tab2.v
+ FROM tab1
+ JOIN tab2
+ ON tab1.k = tab2.k)
+INTERSECT ALL
+SELECT *
+FROM (SELECT tab2.v AS k,
+ tab1.k AS v
+ FROM tab1
+ JOIN tab2
+ ON tab1.k = tab2.k);
+
+-- Group by under intersect all
+SELECT v FROM tab1 GROUP BY v
+INTERSECT ALL
+SELECT k FROM tab2 GROUP BY k;
+
+-- Clean-up
+DROP VIEW IF EXISTS tab1;
+DROP VIEW IF EXISTS tab2;
http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out
new file mode 100644
index 0000000..792791b
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out
@@ -0,0 +1,241 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 17
+
+
+-- !query 0
+CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES
+ (1, 2),
+ (1, 2),
+ (1, 3),
+ (1, 3),
+ (2, 3),
+ (null, null),
+ (null, null)
+ AS tab1(k, v)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES
+ (1, 2),
+ (1, 2),
+ (2, 3),
+ (3, 4),
+ (null, null),
+ (null, null)
+ AS tab2(k, v)
+-- !query 1 schema
+struct<>
+-- !query 1 output
+
+
+
+-- !query 2
+SELECT * FROM tab1
+INTERSECT ALL
+SELECT * FROM tab2
+-- !query 2 schema
+struct<k:int,v:int>
+-- !query 2 output
+1 2
+1 2
+2 3
+NULL NULL
+NULL NULL
+
+
+-- !query 3
+SELECT * FROM tab1
+INTERSECT ALL
+SELECT * FROM tab1 WHERE k = 1
+-- !query 3 schema
+struct<k:int,v:int>
+-- !query 3 output
+1 2
+1 2
+1 3
+1 3
+
+
+-- !query 4
+SELECT * FROM tab1 WHERE k > 2
+INTERSECT ALL
+SELECT * FROM tab2
+-- !query 4 schema
+struct<k:int,v:int>
+-- !query 4 output
+
+
+
+-- !query 5
+SELECT * FROM tab1
+INTERSECT ALL
+SELECT * FROM tab2 WHERE k > 3
+-- !query 5 schema
+struct<k:int,v:int>
+-- !query 5 output
+
+
+
+-- !query 6
+SELECT * FROM tab1
+INTERSECT ALL
+SELECT CAST(1 AS BIGINT), CAST(2 AS BIGINT)
+-- !query 6 schema
+struct<k:bigint,v:bigint>
+-- !query 6 output
+1 2
+
+
+-- !query 7
+SELECT * FROM tab1
+INTERSECT ALL
+SELECT array(1), 2
+-- !query 7 schema
+struct<>
+-- !query 7 output
+org.apache.spark.sql.AnalysisException
+IntersectAll can only be performed on tables with the compatible column types. array<int> <> int at the first column of the second table;
+
+
+-- !query 8
+SELECT k FROM tab1
+INTERSECT ALL
+SELECT k, v FROM tab2
+-- !query 8 schema
+struct<>
+-- !query 8 output
+org.apache.spark.sql.AnalysisException
+IntersectAll can only be performed on tables with the same number of columns, but the first table has 1 columns and the second table has 2 columns;
+
+
+-- !query 9
+SELECT * FROM tab2
+INTERSECT ALL
+SELECT * FROM tab1
+INTERSECT ALL
+SELECT * FROM tab2
+-- !query 9 schema
+struct<k:int,v:int>
+-- !query 9 output
+1 2
+1 2
+2 3
+NULL NULL
+NULL NULL
+
+
+-- !query 10
+SELECT * FROM tab1
+EXCEPT
+SELECT * FROM tab2
+UNION ALL
+(
+SELECT * FROM tab1
+INTERSECT ALL
+SELECT * FROM tab2
+)
+-- !query 10 schema
+struct<k:int,v:int>
+-- !query 10 output
+1 2
+1 2
+1 3
+2 3
+NULL NULL
+NULL NULL
+
+
+-- !query 11
+SELECT * FROM tab1
+EXCEPT
+SELECT * FROM tab2
+EXCEPT
+(
+SELECT * FROM tab1
+INTERSECT ALL
+SELECT * FROM tab2
+)
+-- !query 11 schema
+struct<k:int,v:int>
+-- !query 11 output
+1 3
+
+
+-- !query 12
+SELECT *
+FROM (SELECT tab1.k,
+ tab2.v
+ FROM tab1
+ JOIN tab2
+ ON tab1.k = tab2.k)
+INTERSECT ALL
+SELECT *
+FROM (SELECT tab1.k,
+ tab2.v
+ FROM tab1
+ JOIN tab2
+ ON tab1.k = tab2.k)
+-- !query 12 schema
+struct<k:int,v:int>
+-- !query 12 output
+1 2
+1 2
+1 2
+1 2
+1 2
+1 2
+1 2
+1 2
+2 3
+
+
+-- !query 13
+SELECT *
+FROM (SELECT tab1.k,
+ tab2.v
+ FROM tab1
+ JOIN tab2
+ ON tab1.k = tab2.k)
+INTERSECT ALL
+SELECT *
+FROM (SELECT tab2.v AS k,
+ tab1.k AS v
+ FROM tab1
+ JOIN tab2
+ ON tab1.k = tab2.k)
+-- !query 13 schema
+struct<k:int,v:int>
+-- !query 13 output
+
+
+
+-- !query 14
+SELECT v FROM tab1 GROUP BY v
+INTERSECT ALL
+SELECT k FROM tab2 GROUP BY k
+-- !query 14 schema
+struct<v:int>
+-- !query 14 output
+2
+3
+NULL
+
+
+-- !query 15
+DROP VIEW IF EXISTS tab1
+-- !query 15 schema
+struct<>
+-- !query 15 output
+
+
+
+-- !query 16
+DROP VIEW IF EXISTS tab2
+-- !query 16 schema
+struct<>
+-- !query 16 output
+
http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index af07359..b0e22a5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -749,6 +749,60 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
assert(df4.schema.forall(!_.nullable))
}
+ test("intersectAll") {
+ checkAnswer(
+ lowerCaseDataWithDuplicates.intersectAll(lowerCaseDataWithDuplicates),
+ Row(1, "a") ::
+ Row(2, "b") ::
+ Row(2, "b") ::
+ Row(3, "c") ::
+ Row(3, "c") ::
+ Row(3, "c") ::
+ Row(4, "d") :: Nil)
+ checkAnswer(lowerCaseData.intersectAll(upperCaseData), Nil)
+
+ // check null equality
+ checkAnswer(
+ nullInts.intersectAll(nullInts),
+ Row(1) ::
+ Row(2) ::
+ Row(3) ::
+ Row(null) :: Nil)
+
+ // Duplicate nulls are preserved.
+ checkAnswer(
+ allNulls.intersectAll(allNulls),
+ Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil)
+
+ val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id")
+ val df_right = Seq(1, 2, 2, 3).toDF("id")
+
+ checkAnswer(
+ df_left.intersectAll(df_right),
+ Row(1) :: Row(2) :: Row(2) :: Row(3) :: Nil)
+ }
+
+ test("intersectAll - nullability") {
+ val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF()
+ assert(nonNullableInts.schema.forall(!_.nullable))
+
+ val df1 = nonNullableInts.intersectAll(nullInts)
+ checkAnswer(df1, Row(1) :: Row(3) :: Nil)
+ assert(df1.schema.forall(!_.nullable))
+
+ val df2 = nullInts.intersectAll(nonNullableInts)
+ checkAnswer(df2, Row(1) :: Row(3) :: Nil)
+ assert(df2.schema.forall(!_.nullable))
+
+ val df3 = nullInts.intersectAll(nullInts)
+ checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil)
+ assert(df3.schema.forall(_.nullable))
+
+ val df4 = nonNullableInts.intersectAll(nonNullableInts)
+ checkAnswer(df4, Row(1) :: Row(3) :: Nil)
+ assert(df4.schema.forall(!_.nullable))
+ }
+
test("udf") {
val foo = udf((a: Int, b: String) => a.toString + b)
http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 0cfe260..deea9db 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -136,6 +136,19 @@ private[sql] trait SQLTestData { self =>
df
}
+ protected lazy val lowerCaseDataWithDuplicates: DataFrame = {
+ val df = spark.sparkContext.parallelize(
+ LowerCaseData(1, "a") ::
+ LowerCaseData(2, "b") ::
+ LowerCaseData(2, "b") ::
+ LowerCaseData(3, "c") ::
+ LowerCaseData(3, "c") ::
+ LowerCaseData(3, "c") ::
+ LowerCaseData(4, "d") :: Nil).toDF()
+ df.createOrReplaceTempView("lowerCaseData")
+ df
+ }
+
protected lazy val arrayData: RDD[ArrayData] = {
val rdd = spark.sparkContext.parallelize(
ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) ::