You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ku...@apache.org on 2019/07/05 07:22:46 UTC
[flink] branch master updated: [FLINK-12936][table-planner-blink]
Support "intersect all" and "minus all"
This is an automated email from the ASF dual-hosted git repository.
kurt pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new b4403f2 [FLINK-12936][table-planner-blink] Support "intersect all" and "minus all"
b4403f2 is described below
commit b4403f2657f6fc63298fbd4caa27580e909d3f07
Author: Jingsong Lee <lz...@aliyun.com>
AuthorDate: Fri Jul 5 15:22:35 2019 +0800
[FLINK-12936][table-planner-blink] Support "intersect all" and "minus all"
This closes #8898
---
.../functions/tablefunctions/ReplicateRows.java | 71 ++++++++++
.../table/plan/rules/FlinkBatchRuleSets.scala | 4 +-
.../table/plan/rules/FlinkStreamRuleSets.scala | 4 +-
.../logical/ReplaceIntersectWithSemiJoinRule.scala | 17 ++-
.../logical/ReplaceMinusWithAntiJoinRule.scala | 15 +-
.../logical/ReplaceSetOpWithJoinRuleBase.scala | 58 --------
.../rules/logical/RewriteIntersectAllRule.scala | 143 +++++++++++++++++++
.../plan/rules/logical/RewriteMinusAllRule.scala | 121 ++++++++++++++++
.../flink/table/plan/util/SetOpRewriteUtil.scala | 118 ++++++++++++++++
.../table/plan/batch/sql/SetOperatorsTest.xml | 58 ++++++++
.../rules/logical/RewriteIntersectAllRuleTest.xml | 151 ++++++++++++++++++++
.../plan/rules/logical/RewriteMinusAllRuleTest.xml | 151 ++++++++++++++++++++
.../table/plan/stream/sql/SetOperatorsTest.xml | 56 ++++++++
.../table/plan/batch/sql/SetOperatorsTest.scala | 4 +-
.../ReplaceIntersectWithSemiJoinRuleTest.scala | 6 -
.../logical/ReplaceMinusWithAntiJoinRuleTest.scala | 5 -
...est.scala => RewriteIntersectAllRuleTest.scala} | 28 ++--
...uleTest.scala => RewriteMinusAllRuleTest.scala} | 26 ++--
.../table/plan/stream/sql/SetOperatorsTest.scala | 4 +-
.../runtime/batch/sql/SetOperatorsITCase.scala | 154 +++++++++++++++++++++
.../runtime/stream/sql/SetOperatorsITCase.scala | 134 ++++++++++++++++++
21 files changed, 1208 insertions(+), 120 deletions(-)
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/tablefunctions/ReplicateRows.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/tablefunctions/ReplicateRows.java
new file mode 100644
index 0000000..b7b4338
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/tablefunctions/ReplicateRows.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.functions.tablefunctions;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.table.functions.TableFunction;
+import org.apache.flink.types.Row;
+
+import static org.apache.flink.util.Preconditions.checkArgument;
+
+/**
+ * Replicates a row N times. N is passed as the first argument of the function and the
+ * remaining arguments are the field values of the row to emit.
+ *
+ * <p>This is an internal function solely used by the optimizer to rewrite EXCEPT ALL and
+ * INTERSECT ALL queries.
+ */
+public class ReplicateRows extends TableFunction<Row> {
+
+    private static final long serialVersionUID = 1L;
+
+    /** Types of the fields of the replicated row (the leading count is not included). */
+    private final TypeInformation[] fieldTypes;
+
+    /** Output row, created on first use and reused by every later call; not serialized. */
+    private transient Row reuseRow;
+
+    /**
+     * @param fieldTypes types of the fields of the row to replicate, in field order
+     */
+    public ReplicateRows(TypeInformation[] fieldTypes) {
+        this.fieldTypes = fieldTypes;
+    }
+
+    /**
+     * Emits the row made of {@code inputs[1..]} {@code inputs[0]} times. Nothing is emitted
+     * when the count is zero or negative.
+     *
+     * @param inputs the replication count (a long) followed by exactly one value per entry
+     *               of {@code fieldTypes}; any other arity is rejected by the argument check
+     */
+    public void eval(Object... inputs) {
+        checkArgument(inputs.length == fieldTypes.length + 1);
+        long numRows = (long) inputs[0];
+        if (reuseRow == null) {
+            reuseRow = new Row(fieldTypes.length);
+        }
+        // inputs[0] is the count, so field i of the row comes from inputs[i + 1]
+        for (int i = 0; i < fieldTypes.length; i++) {
+            reuseRow.setField(i, inputs[i + 1]);
+        }
+        // The counter must be a long: with an int counter and a count above
+        // Integer.MAX_VALUE the counter would wrap around and the loop would never end.
+        for (long emitted = 0; emitted < numRows; emitted++) {
+            collect(reuseRow);
+        }
+    }
+
+    /**
+     * @return the row type built from the field types given to the constructor
+     */
+    @Override
+    public TypeInformation<Row> getResultType() {
+        return new RowTypeInfo(fieldTypes);
+    }
+
+    /**
+     * @return the parameter types of {@link #eval}: a LONG for the count followed by the
+     *         field types, regardless of the requested {@code signature}
+     */
+    @Override
+    public TypeInformation<?>[] getParameterTypes(Class<?>[] signature) {
+        TypeInformation[] paraTypes = new TypeInformation[1 + fieldTypes.length];
+        paraTypes[0] = Types.LONG;
+        System.arraycopy(fieldTypes, 0, paraTypes, 1, fieldTypes.length);
+        return paraTypes;
+    }
+}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala
index 71fcb23..cce4016 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala
@@ -303,7 +303,9 @@ object FlinkBatchRuleSets {
// set operators
ReplaceIntersectWithSemiJoinRule.INSTANCE,
- ReplaceMinusWithAntiJoinRule.INSTANCE
+ RewriteIntersectAllRule.INSTANCE,
+ ReplaceMinusWithAntiJoinRule.INSTANCE,
+ RewriteMinusAllRule.INSTANCE
)
/**
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala
index 187354d..35da87e 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala
@@ -276,7 +276,9 @@ object FlinkStreamRuleSets {
// set operators
ReplaceIntersectWithSemiJoinRule.INSTANCE,
- ReplaceMinusWithAntiJoinRule.INSTANCE
+ RewriteIntersectAllRule.INSTANCE,
+ ReplaceMinusWithAntiJoinRule.INSTANCE,
+ RewriteMinusAllRule.INSTANCE
)
/**
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/ReplaceIntersectWithSemiJoinRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/ReplaceIntersectWithSemiJoinRule.scala
index 6ac7dea..3dddeb5 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/ReplaceIntersectWithSemiJoinRule.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/ReplaceIntersectWithSemiJoinRule.scala
@@ -18,8 +18,11 @@
package org.apache.flink.table.plan.rules.logical
+import org.apache.flink.table.plan.util.SetOpRewriteUtil.generateEqualsCondition
+
+import org.apache.calcite.plan.RelOptRule.{any, operand}
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
-import org.apache.calcite.rel.core.{Aggregate, Intersect, Join, JoinRelType}
+import org.apache.calcite.rel.core.{Aggregate, Intersect, Join, JoinRelType, RelFactories}
import scala.collection.JavaConversions._
@@ -27,16 +30,16 @@ import scala.collection.JavaConversions._
* Planner rule that replaces distinct [[Intersect]] with
* a distinct [[Aggregate]] on a SEMI [[Join]].
*
- * <p>Note: Not support Intersect All.
+ * Only handles the case of exactly two inputs.
*/
-class ReplaceIntersectWithSemiJoinRule extends ReplaceSetOpWithJoinRuleBase(
- classOf[Intersect],
+class ReplaceIntersectWithSemiJoinRule extends RelOptRule(
+ operand(classOf[Intersect], any),
+ RelFactories.LOGICAL_BUILDER,
"ReplaceIntersectWithSemiJoinRule") {
override def matches(call: RelOptRuleCall): Boolean = {
val intersect: Intersect = call.rel(0)
- // not support intersect all now.
- intersect.isDistinct
+ !intersect.all && intersect.getInputs.size() == 2
}
override def onMatch(call: RelOptRuleCall): Unit = {
@@ -46,7 +49,7 @@ class ReplaceIntersectWithSemiJoinRule extends ReplaceSetOpWithJoinRuleBase(
val relBuilder = call.builder
val keys = 0 until left.getRowType.getFieldCount
- val conditions = generateCondition(relBuilder, left, right, keys)
+ val conditions = generateEqualsCondition(relBuilder, left, right, keys)
relBuilder.push(left)
relBuilder.push(right)
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRule.scala
index c322cf9..bb2a3f7 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRule.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRule.scala
@@ -18,6 +18,9 @@
package org.apache.flink.table.plan.rules.logical
+import org.apache.flink.table.plan.util.SetOpRewriteUtil.generateEqualsCondition
+
+import org.apache.calcite.plan.RelOptRule.{any, operand}
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
import org.apache.calcite.rel.core._
@@ -27,16 +30,16 @@ import scala.collection.JavaConversions._
* Planner rule that replaces distinct [[Minus]] (SQL keyword: EXCEPT) with
* a distinct [[Aggregate]] on an ANTI [[Join]].
*
- * <p>Note: Not support Minus All.
+ * Only handles the case of exactly two inputs.
*/
-class ReplaceMinusWithAntiJoinRule extends ReplaceSetOpWithJoinRuleBase(
- classOf[Minus],
+class ReplaceMinusWithAntiJoinRule extends RelOptRule(
+ operand(classOf[Minus], any),
+ RelFactories.LOGICAL_BUILDER,
"ReplaceMinusWithAntiJoinRule") {
override def matches(call: RelOptRuleCall): Boolean = {
val minus: Minus = call.rel(0)
- // not support minus all now.
- minus.isDistinct
+ !minus.all && minus.getInputs.size() == 2
}
override def onMatch(call: RelOptRuleCall): Unit = {
@@ -46,7 +49,7 @@ class ReplaceMinusWithAntiJoinRule extends ReplaceSetOpWithJoinRuleBase(
val relBuilder = call.builder
val keys = 0 until left.getRowType.getFieldCount
- val conditions = generateCondition(relBuilder, left, right, keys)
+ val conditions = generateEqualsCondition(relBuilder, left, right, keys)
relBuilder.push(left)
relBuilder.push(right)
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/ReplaceSetOpWithJoinRuleBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/ReplaceSetOpWithJoinRuleBase.scala
deleted file mode 100644
index 1f400a0..0000000
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/ReplaceSetOpWithJoinRuleBase.scala
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.plan.rules.logical
-
-import org.apache.calcite.plan.RelOptRule.{any, operand}
-import org.apache.calcite.plan.{RelOptRule, RelOptUtil}
-import org.apache.calcite.rel.RelNode
-import org.apache.calcite.rel.core.{RelFactories, SetOp}
-import org.apache.calcite.rex.RexNode
-import org.apache.calcite.sql.fun.SqlStdOperatorTable
-import org.apache.calcite.tools.RelBuilder
-
-/**
- * Base class that replace [[SetOp]] to [[org.apache.calcite.rel.core.Join]].
- */
-abstract class ReplaceSetOpWithJoinRuleBase[T <: SetOp](
- clazz: Class[T],
- description: String)
- extends RelOptRule(
- operand(clazz, any),
- RelFactories.LOGICAL_BUILDER,
- description) {
-
- protected def generateCondition(
- relBuilder: RelBuilder,
- left: RelNode,
- right: RelNode,
- keys: Seq[Int]): Seq[RexNode] = {
- val rexBuilder = relBuilder.getRexBuilder
- val leftTypes = RelOptUtil.getFieldTypeList(left.getRowType)
- val rightTypes = RelOptUtil.getFieldTypeList(right.getRowType)
- val conditions = keys.map { key =>
- val leftRex = rexBuilder.makeInputRef(leftTypes.get(key), key)
- val rightRex = rexBuilder.makeInputRef(rightTypes.get(key), leftTypes.size + key)
- val equalCond = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, leftRex, rightRex)
- relBuilder.or(
- equalCond,
- relBuilder.and(relBuilder.isNull(leftRex), relBuilder.isNull(rightRex)))
- }
- conditions
- }
-}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/RewriteIntersectAllRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/RewriteIntersectAllRule.scala
new file mode 100644
index 0000000..60da40d
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/RewriteIntersectAllRule.scala
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.rules.logical
+
+import org.apache.flink.table.functions.sql.FlinkSqlOperatorTable.{GREATER_THAN, GREATER_THAN_OR_EQUAL, IF}
+import org.apache.flink.table.plan.util.SetOpRewriteUtil.replicateRows
+
+import org.apache.calcite.plan.RelOptRule.{any, operand}
+import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
+import org.apache.calcite.rel.core.{Intersect, RelFactories}
+import org.apache.calcite.sql.`type`.SqlTypeName
+import org.apache.calcite.util.Util
+
+import scala.collection.JavaConversions._
+
+/**
+ * Replaces logical [[Intersect]] operator using a combination of union all, aggregate
+ * and table function.
+ *
+ * Original Query :
+ * {{{
+ *   SELECT c1 FROM ut1 INTERSECT ALL SELECT c1 FROM ut2
+ * }}}
+ *
+ * Rewritten Query:
+ * {{{
+ *   SELECT c1
+ *   FROM (
+ *     SELECT c1, IF(vcol_left_cnt > vcol_right_cnt, vcol_right_cnt, vcol_left_cnt) AS min_count
+ *     FROM (
+ *       SELECT
+ *         c1,
+ *         count(vcol_left_marker) as vcol_left_cnt,
+ *         count(vcol_right_marker) as vcol_right_cnt
+ *       FROM (
+ *         SELECT c1, true as vcol_left_marker, null as vcol_right_marker FROM ut1
+ *         UNION ALL
+ *         SELECT c1, null as vcol_left_marker, true as vcol_right_marker FROM ut2
+ *       ) AS union_all
+ *       GROUP BY c1
+ *     )
+ *     WHERE vcol_left_cnt >= 1 AND vcol_right_cnt >= 1
+ *   ),
+ *   LATERAL TABLE(replicate_row(min_count, c1)) AS T(c1)
+ * }}}
+ *
+ * Each distinct row is emitted min(left count, right count) times.
+ *
+ * Only handles the case of exactly two inputs.
+ */
+class RewriteIntersectAllRule extends RelOptRule(
+ operand(classOf[Intersect], any),
+ RelFactories.LOGICAL_BUILDER,
+ "RewriteIntersectAllRule") {
+
+ // Matches only INTERSECT ALL with exactly two inputs.
+ override def matches(call: RelOptRuleCall): Boolean = {
+ val intersect: Intersect = call.rel(0)
+ intersect.all && intersect.getInputs.size() == 2
+ }
+
+ // Builds the rewritten plan described in the class comment and registers it as the
+ // equivalent of the matched Intersect.
+ override def onMatch(call: RelOptRuleCall): Unit = {
+ val intersect: Intersect = call.rel(0)
+ val left = intersect.getInput(0)
+ val right = intersect.getInput(1)
+
+ // indices of all output fields of the intersect; they are the grouping keys below
+ val fields = Util.range(intersect.getRowType.getFieldCount)
+
+ // 1. add marker to left rel node
+ // left rows get vcol_left_marker = true and a NULL vcol_right_marker
+ val leftBuilder = call.builder
+ val boolType = leftBuilder.getTypeFactory.createSqlType(SqlTypeName.BOOLEAN)
+ val leftWithMarker = leftBuilder
+ .push(left)
+ .project(
+ leftBuilder.fields(fields) ++ Seq(
+ leftBuilder.alias(leftBuilder.literal(true), "vcol_left_marker"),
+ leftBuilder.alias(
+ leftBuilder.getRexBuilder.makeNullLiteral(boolType), "vcol_right_marker")))
+ .build()
+
+ // 2. add marker to right rel node
+ // right rows get the mirrored markers: NULL vcol_left_marker and vcol_right_marker = true
+ val rightBuilder = call.builder
+ val rightWithMarker = rightBuilder
+ .push(right)
+ .project(
+ rightBuilder.fields(fields) ++ Seq(
+ rightBuilder.alias(
+ rightBuilder.getRexBuilder.makeNullLiteral(boolType), "vcol_left_marker"),
+ rightBuilder.alias(rightBuilder.literal(true), "vcol_right_marker")))
+ .build()
+
+ // 3. union and aggregate
+ // COUNT skips NULL markers, so the two counts are the per-side occurrences of each row.
+ // The filter keeps rows present on both sides; the final projection puts
+ // min(left count, right count) first, followed by the original fields, which is the
+ // layout replicateRows reads (count first, then the row fields).
+ val builder = call.builder
+ builder
+ .push(leftWithMarker)
+ .push(rightWithMarker)
+ .union(true)
+ .aggregate(
+ builder.groupKey(builder.fields(fields)),
+ builder.count(false, "vcol_left_cnt", builder.field("vcol_left_marker")),
+ builder.count(false, "vcol_right_cnt", builder.field("vcol_right_marker")))
+ .filter(builder.and(
+ builder.call(
+ GREATER_THAN_OR_EQUAL,
+ builder.field("vcol_left_cnt"),
+ builder.literal(1)),
+ builder.call(
+ GREATER_THAN_OR_EQUAL,
+ builder.field("vcol_right_cnt"),
+ builder.literal(1))))
+ .project(Seq(builder.call(
+ IF,
+ builder.call(
+ GREATER_THAN,
+ builder.field("vcol_left_cnt"),
+ builder.field("vcol_right_cnt")),
+ builder.field("vcol_right_cnt"),
+ builder.field("vcol_left_cnt"))) ++ builder.fields(fields))
+
+ // 4. add table function to replicate rows
+ val output = replicateRows(builder, intersect.getRowType, fields)
+
+ call.transformTo(output)
+ }
+}
+
+// Holds the shared rule instance that is registered in the planner rule sets.
+object RewriteIntersectAllRule {
+ val INSTANCE: RelOptRule = new RewriteIntersectAllRule
+}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/RewriteMinusAllRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/RewriteMinusAllRule.scala
new file mode 100644
index 0000000..272837a
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/RewriteMinusAllRule.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.rules.logical
+
+import org.apache.flink.table.functions.sql.FlinkSqlOperatorTable.GREATER_THAN
+import org.apache.flink.table.plan.util.SetOpRewriteUtil.replicateRows
+
+import org.apache.calcite.plan.RelOptRule.{any, operand}
+import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
+import org.apache.calcite.rel.core.{Minus, RelFactories}
+import org.apache.calcite.sql.`type`.SqlTypeName.BIGINT
+import org.apache.calcite.util.Util
+
+import scala.collection.JavaConversions._
+
+/**
+ * Replaces logical [[Minus]] operator using a combination of union all, aggregate
+ * and table function.
+ *
+ * Original Query :
+ * {{{
+ *   SELECT c1 FROM ut1 EXCEPT ALL SELECT c1 FROM ut2
+ * }}}
+ *
+ * Rewritten Query:
+ * {{{
+ *   SELECT c1
+ *   FROM (
+ *     SELECT c1, sum_vcol_marker
+ *     FROM (
+ *       SELECT c1, sum(vcol_marker) AS sum_vcol_marker
+ *       FROM (
+ *         SELECT c1, 1L as vcol_marker FROM ut1
+ *         UNION ALL
+ *         SELECT c1, -1L as vcol_marker FROM ut2
+ *       ) AS union_all
+ *       GROUP BY union_all.c1
+ *     )
+ *     WHERE sum_vcol_marker > 0
+ *   ),
+ *   LATERAL TABLE(replicate_row(sum_vcol_marker, c1)) AS T(c1)
+ * }}}
+ *
+ * Each distinct row is emitted (left count - right count) times when that difference
+ * is positive, and not at all otherwise.
+ *
+ * Only handles the case of exactly two inputs.
+ */
+class RewriteMinusAllRule extends RelOptRule(
+  operand(classOf[Minus], any),
+  RelFactories.LOGICAL_BUILDER,
+  "RewriteMinusAllRule") {
+
+  /** Matches only EXCEPT ALL with exactly two inputs. */
+  override def matches(call: RelOptRuleCall): Boolean = {
+    val minus: Minus = call.rel(0)
+    minus.all && minus.getInputs.size() == 2
+  }
+
+  /**
+   * Builds the rewritten plan described in the class comment and registers it as the
+   * equivalent of the matched Minus.
+   */
+  override def onMatch(call: RelOptRuleCall): Unit = {
+    val minus: Minus = call.rel(0)
+    val left = minus.getInput(0)
+    val right = minus.getInput(1)
+
+    // indices of all output fields of the minus; they are the grouping keys below
+    val fields = Util.range(minus.getRowType.getFieldCount)
+
+    // 1. add vcol_marker to left rel node: every left row counts as +1
+    val leftBuilder = call.builder
+    val leftWithAddedVirtualCols = leftBuilder
+      .push(left)
+      .project(leftBuilder.fields(fields) ++
+        Seq(leftBuilder.alias(
+          leftBuilder.cast(leftBuilder.literal(1L), BIGINT), "vcol_marker")))
+      .build()
+
+    // 2. add vcol_marker to right rel node: every right row counts as -1
+    // (built with rightBuilder throughout, mirroring step 1)
+    val rightBuilder = call.builder
+    val rightWithAddedVirtualCols = rightBuilder
+      .push(right)
+      .project(rightBuilder.fields(fields) ++
+        Seq(rightBuilder.alias(
+          rightBuilder.cast(rightBuilder.literal(-1L), BIGINT), "vcol_marker")))
+      .build()
+
+    // 3. add union all and aggregate
+    // The sum of the markers per row is (left occurrences - right occurrences). Only rows
+    // with a positive sum survive, and the sum is projected first, followed by the original
+    // fields, which is the layout replicateRows reads (count first, then the row fields).
+    val builder = call.builder
+    builder
+      .push(leftWithAddedVirtualCols)
+      .push(rightWithAddedVirtualCols)
+      .union(true)
+      .aggregate(
+        builder.groupKey(builder.fields(fields)),
+        builder.sum(false, "sum_vcol_marker", builder.field("vcol_marker")))
+      .filter(builder.call(
+        GREATER_THAN,
+        builder.field("sum_vcol_marker"),
+        builder.literal(0)))
+      .project(Seq(builder.field("sum_vcol_marker")) ++ builder.fields(fields))
+
+    // 4. add table function to replicate rows
+    val output = replicateRows(builder, minus.getRowType, fields)
+
+    call.transformTo(output)
+  }
+}
+
+/** Holds the shared rule instance that is registered in the planner rule sets. */
+object RewriteMinusAllRule {
+  val INSTANCE: RelOptRule = new RewriteMinusAllRule
+}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/SetOpRewriteUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/SetOpRewriteUtil.scala
new file mode 100644
index 0000000..5030694
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/SetOpRewriteUtil.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.util
+
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.table.calcite.FlinkTypeFactory
+import org.apache.flink.table.calcite.FlinkTypeFactory.toLogicalRowType
+import org.apache.flink.table.functions.tablefunctions.ReplicateRows
+import org.apache.flink.table.functions.utils.{TableSqlFunction, UserDefinedFunctionUtils}
+import org.apache.flink.table.plan.schema.TypedFlinkTableFunction
+import org.apache.flink.table.types.TypeInfoLogicalTypeConverter.fromLogicalTypeToTypeInfo
+import org.apache.flink.table.types.utils.TypeConversions.fromLegacyInfoToDataType
+
+import org.apache.calcite.plan.RelOptUtil
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.core.{JoinRelType, SetOp}
+import org.apache.calcite.rel.logical.LogicalTableFunctionScan
+import org.apache.calcite.rex.RexNode
+import org.apache.calcite.sql.fun.SqlStdOperatorTable
+import org.apache.calcite.tools.RelBuilder
+import org.apache.calcite.util.Util
+
+import java.util
+
+import scala.collection.JavaConversions._
+
+/**
+ * Util class that rewrite [[SetOp]].
+ */
+object SetOpRewriteUtil {
+
+ /**
+ * Generate equals condition by keys (The index on both sides is the same) to
+ * join left relNode and right relNode.
+ *
+ * The comparison is null-safe: for every key the generated condition is
+ * `left = right OR (left IS NULL AND right IS NULL)`, so two NULL values are
+ * treated as equal.
+ *
+ * @param relBuilder builder used to create the OR / AND / IS NULL expressions
+ * @param left left input of the join; its field types type the left references
+ * @param right right input of the join; its fields are referenced with an offset of the
+ * left field count, i.e. as positions in the joined row
+ * @param keys field indices to compare, the same index being used on both sides
+ * @return one condition per key, in the order of `keys`
+ */
+ def generateEqualsCondition(
+ relBuilder: RelBuilder,
+ left: RelNode,
+ right: RelNode,
+ keys: Seq[Int]): Seq[RexNode] = {
+ val rexBuilder = relBuilder.getRexBuilder
+ val leftTypes = RelOptUtil.getFieldTypeList(left.getRowType)
+ val rightTypes = RelOptUtil.getFieldTypeList(right.getRowType)
+ val conditions = keys.map { key =>
+ val leftRex = rexBuilder.makeInputRef(leftTypes.get(key), key)
+ // right-side fields follow all left-side fields in the joined row
+ val rightRex = rexBuilder.makeInputRef(rightTypes.get(key), leftTypes.size + key)
+ val equalCond = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, leftRex, rightRex)
+ relBuilder.or(
+ equalCond,
+ relBuilder.and(relBuilder.isNull(leftRex), relBuilder.isNull(rightRex)))
+ }
+ conditions
+ }
+
+ /**
+ * Use table function to replicate the row N times. First field is long type,
+ * and the rest are the row fields.
+ *
+ * The relation on top of `builder` is the input: its first `fields.size() + 1` fields
+ * (the count followed by the row fields) are passed to a [[ReplicateRows]] function. The
+ * function scan is joined to the input as a correlated inner join, and only the columns
+ * produced by the function are projected.
+ *
+ * @param builder builder whose top relation holds the count and the row fields
+ * @param outputType row type of the replicated rows; provides the names and types of
+ * the function's result
+ * @param fields indices of the row fields; only the size of this list is used here
+ * @return the relation that emits every input row as many times as its count
+ */
+ def replicateRows(
+ builder: RelBuilder, outputType: RelDataType, fields: util.List[Integer]): RelNode = {
+ // construct LogicalTableFunctionScan
+ val logicalType = toLogicalRowType(outputType)
+ val fieldNames = outputType.getFieldNames.toSeq.toArray
+ val fieldTypes = logicalType.getChildren.map(fromLogicalTypeToTypeInfo).toArray
+ val tf = new ReplicateRows(fieldTypes)
+ val resultType = fromLegacyInfoToDataType(new RowTypeInfo(fieldTypes, fieldNames))
+ val function = new TypedFlinkTableFunction(tf, resultType)
+ val typeFactory = builder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
+ val sqlFunction = new TableSqlFunction(
+ tf.functionIdentifier,
+ tf.toString,
+ tf,
+ resultType,
+ typeFactory,
+ function)
+
+ // The function call is created before the scan is pushed, so its field references are
+ // resolved while the input relation is still on top of the builder.
+ val cluster = builder.peek().getCluster
+ val scan = LogicalTableFunctionScan.create(
+ cluster,
+ new util.ArrayList[RelNode](),
+ builder.call(
+ sqlFunction,
+ builder.fields(Util.range(fields.size() + 1))),
+ function.getElementType(null),
+ UserDefinedFunctionUtils.buildRelDataType(
+ builder.getTypeFactory,
+ logicalType,
+ fieldNames,
+ fieldNames.indices.toArray),
+ null)
+ builder.push(scan)
+
+ // correlated join
+ // the joined row is [count, row fields, function output]; keep only the function output,
+ // which starts right after the fields.size() + 1 input columns
+ val corSet = Set(cluster.createCorrel())
+ val output = builder
+ .join(JoinRelType.INNER, builder.literal(true), corSet)
+ .project(builder.fields(Util.range(fields.size() + 1, fields.size() * 2 + 1)))
+ .build()
+ output
+ }
+}
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SetOperatorsTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SetOperatorsTest.xml
index 76cbfc3..a6746ff 100644
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SetOperatorsTest.xml
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SetOperatorsTest.xml
@@ -42,6 +42,35 @@ HashAggregate(isMerge=[false], groupBy=[c], select=[c])
]]>
</Resource>
</TestCase>
+ <TestCase name="testIntersectAll">
+ <Resource name="sql">
+ <![CDATA[SELECT c FROM T1 INTERSECT ALL SELECT f FROM T2]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalIntersect(all=[true])
+:- LogicalProject(c=[$2])
+: +- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b, c)]]])
++- LogicalProject(f=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(d, e, f)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[c0 AS c])
++- Correlate(invocation=[org$apache$flink$table$functions$tablefunctions$ReplicateRows$8a783e32f854e18fdaeaa274e3162b0b($0, $1)], correlate=[table(ReplicateRows($f0,c))], select=[$f0,c,c0], rowType=[RecordType(BIGINT $f0, VARCHAR(2147483647) c, VARCHAR(2147483647) c0)], joinType=[INNER])
+ +- Calc(select=[IF(>(vcol_left_cnt, vcol_right_cnt), vcol_right_cnt, vcol_left_cnt) AS $f0, c], where=[AND(>=(vcol_left_cnt, 1), >=(vcol_right_cnt, 1))])
+ +- HashAggregate(isMerge=[true], groupBy=[c], select=[c, Final_COUNT(count$0) AS vcol_left_cnt, Final_COUNT(count$1) AS vcol_right_cnt])
+ +- Exchange(distribution=[hash[c]])
+ +- LocalHashAggregate(groupBy=[c], select=[c, Partial_COUNT(vcol_left_marker) AS count$0, Partial_COUNT(vcol_right_marker) AS count$1])
+ +- Union(all=[true], union=[c, vcol_left_marker, vcol_right_marker])
+ :- Calc(select=[c, true AS vcol_left_marker, null:BOOLEAN AS vcol_right_marker])
+ : +- TableSourceScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+ +- Calc(select=[f, null:BOOLEAN AS vcol_left_marker, true AS vcol_right_marker])
+ +- TableSourceScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
<TestCase name="testIntersectLeftIsEmpty">
<Resource name="sql">
<![CDATA[SELECT c FROM T1 WHERE 1=0 INTERSECT SELECT f FROM T2]]>
@@ -108,6 +137,35 @@ HashAggregate(isMerge=[false], groupBy=[c], select=[c])
]]>
</Resource>
</TestCase>
+ <TestCase name="testMinusAll">
+ <Resource name="sql">
+ <![CDATA[SELECT c FROM T1 EXCEPT ALL SELECT f FROM T2]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalMinus(all=[true])
+:- LogicalProject(c=[$2])
+: +- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b, c)]]])
++- LogicalProject(f=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(d, e, f)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[c0 AS c])
++- Correlate(invocation=[org$apache$flink$table$functions$tablefunctions$ReplicateRows$8a783e32f854e18fdaeaa274e3162b0b($0, $1)], correlate=[table(ReplicateRows(sum_vcol_marker,c))], select=[sum_vcol_marker,c,c0], rowType=[RecordType(BIGINT sum_vcol_marker, VARCHAR(2147483647) c, VARCHAR(2147483647) c0)], joinType=[INNER])
+ +- Calc(select=[sum_vcol_marker, c], where=[>(sum_vcol_marker, 0)])
+ +- HashAggregate(isMerge=[true], groupBy=[c], select=[c, Final_SUM(sum$0) AS sum_vcol_marker])
+ +- Exchange(distribution=[hash[c]])
+ +- LocalHashAggregate(groupBy=[c], select=[c, Partial_SUM(vcol_marker) AS sum$0])
+ +- Union(all=[true], union=[c, vcol_marker])
+ :- Calc(select=[c, 1:BIGINT AS vcol_marker])
+ : +- TableSourceScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+ +- Calc(select=[f, -1:BIGINT AS vcol_marker])
+ +- TableSourceScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
<TestCase name="testMinusLeftIsEmpty">
<Resource name="sql">
<![CDATA[SELECT c FROM T1 WHERE 1=0 EXCEPT SELECT f FROM T2]]>
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/RewriteIntersectAllRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/RewriteIntersectAllRuleTest.xml
new file mode 100644
index 0000000..fecf50a
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/RewriteIntersectAllRuleTest.xml
@@ -0,0 +1,151 @@
+<?xml version="1.0" ?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements. See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to you under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+<Root>
+ <TestCase name="testIntersectAll">
+ <Resource name="sql">
+ <![CDATA[SELECT c FROM T1 INTERSECT ALL SELECT f FROM T2]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalIntersect(all=[true])
+:- LogicalProject(c=[$2])
+: +- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b, c)]]])
++- LogicalProject(f=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(d, e, f)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(c=[$2])
++- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{}])
+ :- LogicalProject($f0=[IF(>($1, $2), $2, $1)], c=[$0])
+ : +- LogicalFilter(condition=[AND(>=($1, 1), >=($2, 1))])
+ : +- LogicalAggregate(group=[{0}], vcol_left_cnt=[COUNT($1)], vcol_right_cnt=[COUNT($2)])
+ : +- LogicalUnion(all=[true])
+ : :- LogicalProject(c=[$0], vcol_left_marker=[true], vcol_right_marker=[null:BOOLEAN])
+ : : +- LogicalProject(c=[$2])
+ : : +- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b, c)]]])
+ : +- LogicalProject(f=[$0], vcol_left_marker=[null:BOOLEAN], vcol_right_marker=[true])
+ : +- LogicalProject(f=[$2])
+ : +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(d, e, f)]]])
+ +- LogicalTableFunctionScan(invocation=[org$apache$flink$table$functions$tablefunctions$ReplicateRows$8a783e32f854e18fdaeaa274e3162b0b($0, $1)], rowType=[RecordType(VARCHAR(2147483647) c)], elementType=[class [Ljava.lang.Object;])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testIntersectAllLeftIsEmpty">
+ <Resource name="sql">
+ <![CDATA[SELECT c FROM T1 WHERE 1=0 INTERSECT ALL SELECT f FROM T2]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalIntersect(all=[true])
+:- LogicalProject(c=[$2])
+: +- LogicalFilter(condition=[=(1, 0)])
+: +- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b, c)]]])
++- LogicalProject(f=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(d, e, f)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(c=[$2])
++- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{}])
+ :- LogicalProject($f0=[IF(>($1, $2), $2, $1)], c=[$0])
+ : +- LogicalFilter(condition=[AND(>=($1, 1), >=($2, 1))])
+ : +- LogicalAggregate(group=[{0}], vcol_left_cnt=[COUNT($1)], vcol_right_cnt=[COUNT($2)])
+ : +- LogicalUnion(all=[true])
+ : :- LogicalProject(c=[$0], vcol_left_marker=[true], vcol_right_marker=[null:BOOLEAN])
+ : : +- LogicalProject(c=[$2])
+ : : +- LogicalFilter(condition=[=(1, 0)])
+ : : +- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b, c)]]])
+ : +- LogicalProject(f=[$0], vcol_left_marker=[null:BOOLEAN], vcol_right_marker=[true])
+ : +- LogicalProject(f=[$2])
+ : +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(d, e, f)]]])
+ +- LogicalTableFunctionScan(invocation=[org$apache$flink$table$functions$tablefunctions$ReplicateRows$8a783e32f854e18fdaeaa274e3162b0b($0, $1)], rowType=[RecordType(VARCHAR(2147483647) c)], elementType=[class [Ljava.lang.Object;])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testIntersectAllWithFilter">
+ <Resource name="sql">
+ <![CDATA[SELECT c FROM ((SELECT * FROM T1) INTERSECT ALL (SELECT * FROM T2)) WHERE a > 1]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(c=[$2])
++- LogicalFilter(condition=[>($0, 1)])
+ +- LogicalIntersect(all=[true])
+ :- LogicalProject(a=[$0], b=[$1], c=[$2])
+ : +- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b, c)]]])
+ +- LogicalProject(d=[$0], e=[$1], f=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(d, e, f)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(c=[$2])
++- LogicalFilter(condition=[>($0, 1)])
+ +- LogicalProject(a=[$4], b=[$5], c=[$6])
+ +- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{}])
+ :- LogicalProject($f0=[IF(>($3, $4), $4, $3)], a=[$0], b=[$1], c=[$2])
+ : +- LogicalFilter(condition=[AND(>=($3, 1), >=($4, 1))])
+ : +- LogicalAggregate(group=[{0, 1, 2}], vcol_left_cnt=[COUNT($3)], vcol_right_cnt=[COUNT($4)])
+ : +- LogicalUnion(all=[true])
+ : :- LogicalProject(a=[$0], b=[$1], c=[$2], vcol_left_marker=[true], vcol_right_marker=[null:BOOLEAN])
+ : : +- LogicalProject(a=[$0], b=[$1], c=[$2])
+ : : +- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b, c)]]])
+ : +- LogicalProject(d=[$0], e=[$1], f=[$2], vcol_left_marker=[null:BOOLEAN], vcol_right_marker=[true])
+ : +- LogicalProject(d=[$0], e=[$1], f=[$2])
+ : +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(d, e, f)]]])
+ +- LogicalTableFunctionScan(invocation=[org$apache$flink$table$functions$tablefunctions$ReplicateRows$a265580be75179078c2732913dc90daa($0, $1, $2, $3)], rowType=[RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c)], elementType=[class [Ljava.lang.Object;])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testIntersectAllRightIsEmpty">
+ <Resource name="sql">
+ <![CDATA[SELECT c FROM T1 INTERSECT ALL SELECT f FROM T2 WHERE 1=0]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalIntersect(all=[true])
+:- LogicalProject(c=[$2])
+: +- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b, c)]]])
++- LogicalProject(f=[$2])
+ +- LogicalFilter(condition=[=(1, 0)])
+ +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(d, e, f)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(c=[$2])
++- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{}])
+ :- LogicalProject($f0=[IF(>($1, $2), $2, $1)], c=[$0])
+ : +- LogicalFilter(condition=[AND(>=($1, 1), >=($2, 1))])
+ : +- LogicalAggregate(group=[{0}], vcol_left_cnt=[COUNT($1)], vcol_right_cnt=[COUNT($2)])
+ : +- LogicalUnion(all=[true])
+ : :- LogicalProject(c=[$0], vcol_left_marker=[true], vcol_right_marker=[null:BOOLEAN])
+ : : +- LogicalProject(c=[$2])
+ : : +- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b, c)]]])
+ : +- LogicalProject(f=[$0], vcol_left_marker=[null:BOOLEAN], vcol_right_marker=[true])
+ : +- LogicalProject(f=[$2])
+ : +- LogicalFilter(condition=[=(1, 0)])
+ : +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(d, e, f)]]])
+ +- LogicalTableFunctionScan(invocation=[org$apache$flink$table$functions$tablefunctions$ReplicateRows$8a783e32f854e18fdaeaa274e3162b0b($0, $1)], rowType=[RecordType(VARCHAR(2147483647) c)], elementType=[class [Ljava.lang.Object;])
+]]>
+ </Resource>
+ </TestCase>
+</Root>
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/RewriteMinusAllRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/RewriteMinusAllRuleTest.xml
new file mode 100644
index 0000000..1cae444
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/RewriteMinusAllRuleTest.xml
@@ -0,0 +1,151 @@
+<?xml version="1.0" ?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements. See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to you under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+<Root>
+ <TestCase name="testExceptAll">
+ <Resource name="sql">
+ <![CDATA[SELECT c FROM T1 EXCEPT ALL SELECT f FROM T2]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalMinus(all=[true])
+:- LogicalProject(c=[$2])
+: +- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b, c)]]])
++- LogicalProject(f=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(d, e, f)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(c=[$2])
++- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{}])
+ :- LogicalProject(sum_vcol_marker=[$1], c=[$0])
+ : +- LogicalFilter(condition=[>($1, 0)])
+ : +- LogicalAggregate(group=[{0}], sum_vcol_marker=[SUM($1)])
+ : +- LogicalUnion(all=[true])
+ : :- LogicalProject(c=[$0], vcol_marker=[1:BIGINT])
+ : : +- LogicalProject(c=[$2])
+ : : +- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b, c)]]])
+ : +- LogicalProject(f=[$0], vcol_marker=[-1:BIGINT])
+ : +- LogicalProject(f=[$2])
+ : +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(d, e, f)]]])
+ +- LogicalTableFunctionScan(invocation=[org$apache$flink$table$functions$tablefunctions$ReplicateRows$8a783e32f854e18fdaeaa274e3162b0b($0, $1)], rowType=[RecordType(VARCHAR(2147483647) c)], elementType=[class [Ljava.lang.Object;])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testExceptAllLeftIsEmpty">
+ <Resource name="sql">
+ <![CDATA[SELECT c FROM T1 WHERE 1=0 EXCEPT ALL SELECT f FROM T2]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalMinus(all=[true])
+:- LogicalProject(c=[$2])
+: +- LogicalFilter(condition=[=(1, 0)])
+: +- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b, c)]]])
++- LogicalProject(f=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(d, e, f)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(c=[$2])
++- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{}])
+ :- LogicalProject(sum_vcol_marker=[$1], c=[$0])
+ : +- LogicalFilter(condition=[>($1, 0)])
+ : +- LogicalAggregate(group=[{0}], sum_vcol_marker=[SUM($1)])
+ : +- LogicalUnion(all=[true])
+ : :- LogicalProject(c=[$0], vcol_marker=[1:BIGINT])
+ : : +- LogicalProject(c=[$2])
+ : : +- LogicalFilter(condition=[=(1, 0)])
+ : : +- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b, c)]]])
+ : +- LogicalProject(f=[$0], vcol_marker=[-1:BIGINT])
+ : +- LogicalProject(f=[$2])
+ : +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(d, e, f)]]])
+ +- LogicalTableFunctionScan(invocation=[org$apache$flink$table$functions$tablefunctions$ReplicateRows$8a783e32f854e18fdaeaa274e3162b0b($0, $1)], rowType=[RecordType(VARCHAR(2147483647) c)], elementType=[class [Ljava.lang.Object;])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testExceptAllRightIsEmpty">
+ <Resource name="sql">
+ <![CDATA[SELECT c FROM T1 EXCEPT ALL SELECT f FROM T2 WHERE 1=0]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalMinus(all=[true])
+:- LogicalProject(c=[$2])
+: +- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b, c)]]])
++- LogicalProject(f=[$2])
+ +- LogicalFilter(condition=[=(1, 0)])
+ +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(d, e, f)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(c=[$2])
++- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{}])
+ :- LogicalProject(sum_vcol_marker=[$1], c=[$0])
+ : +- LogicalFilter(condition=[>($1, 0)])
+ : +- LogicalAggregate(group=[{0}], sum_vcol_marker=[SUM($1)])
+ : +- LogicalUnion(all=[true])
+ : :- LogicalProject(c=[$0], vcol_marker=[1:BIGINT])
+ : : +- LogicalProject(c=[$2])
+ : : +- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b, c)]]])
+ : +- LogicalProject(f=[$0], vcol_marker=[-1:BIGINT])
+ : +- LogicalProject(f=[$2])
+ : +- LogicalFilter(condition=[=(1, 0)])
+ : +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(d, e, f)]]])
+ +- LogicalTableFunctionScan(invocation=[org$apache$flink$table$functions$tablefunctions$ReplicateRows$8a783e32f854e18fdaeaa274e3162b0b($0, $1)], rowType=[RecordType(VARCHAR(2147483647) c)], elementType=[class [Ljava.lang.Object;])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testExceptAllWithFilter">
+ <Resource name="sql">
+ <![CDATA[SELECT c FROM (SELECT * FROM T1 EXCEPT ALL (SELECT * FROM T2)) WHERE b < 2]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(c=[$2])
++- LogicalFilter(condition=[<($1, 2)])
+ +- LogicalMinus(all=[true])
+ :- LogicalProject(a=[$0], b=[$1], c=[$2])
+ : +- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b, c)]]])
+ +- LogicalProject(d=[$0], e=[$1], f=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(d, e, f)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(c=[$2])
++- LogicalFilter(condition=[<($1, 2)])
+ +- LogicalProject(a=[$4], b=[$5], c=[$6])
+ +- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{}])
+ :- LogicalProject(sum_vcol_marker=[$3], a=[$0], b=[$1], c=[$2])
+ : +- LogicalFilter(condition=[>($3, 0)])
+ : +- LogicalAggregate(group=[{0, 1, 2}], sum_vcol_marker=[SUM($3)])
+ : +- LogicalUnion(all=[true])
+ : :- LogicalProject(a=[$0], b=[$1], c=[$2], vcol_marker=[1:BIGINT])
+ : : +- LogicalProject(a=[$0], b=[$1], c=[$2])
+ : : +- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b, c)]]])
+ : +- LogicalProject(d=[$0], e=[$1], f=[$2], vcol_marker=[-1:BIGINT])
+ : +- LogicalProject(d=[$0], e=[$1], f=[$2])
+ : +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(d, e, f)]]])
+ +- LogicalTableFunctionScan(invocation=[org$apache$flink$table$functions$tablefunctions$ReplicateRows$a265580be75179078c2732913dc90daa($0, $1, $2, $3)], rowType=[RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c)], elementType=[class [Ljava.lang.Object;])
+]]>
+ </Resource>
+ </TestCase>
+</Root>
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/SetOperatorsTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/SetOperatorsTest.xml
index 8ccf630..8094d28 100644
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/SetOperatorsTest.xml
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/SetOperatorsTest.xml
@@ -43,6 +43,34 @@ GroupAggregate(groupBy=[c], select=[c])
]]>
</Resource>
</TestCase>
+ <TestCase name="testIntersectAll">
+ <Resource name="sql">
+ <![CDATA[SELECT c FROM T1 INTERSECT ALL SELECT f FROM T2]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalIntersect(all=[true])
+:- LogicalProject(c=[$2])
+: +- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b, c)]]])
++- LogicalProject(f=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(d, e, f)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[c0 AS c])
++- Correlate(invocation=[org$apache$flink$table$functions$tablefunctions$ReplicateRows$8a783e32f854e18fdaeaa274e3162b0b($0, $1)], correlate=[table(ReplicateRows($f0,c))], select=[$f0,c,c0], rowType=[RecordType(BIGINT $f0, VARCHAR(2147483647) c, VARCHAR(2147483647) c0)], joinType=[INNER])
+ +- Calc(select=[IF(>(vcol_left_cnt, vcol_right_cnt), vcol_right_cnt, vcol_left_cnt) AS $f0, c], where=[AND(>=(vcol_left_cnt, 1), >=(vcol_right_cnt, 1))])
+ +- GroupAggregate(groupBy=[c], select=[c, COUNT(vcol_left_marker) AS vcol_left_cnt, COUNT(vcol_right_marker) AS vcol_right_cnt])
+ +- Exchange(distribution=[hash[c]])
+ +- Union(all=[true], union=[c, vcol_left_marker, vcol_right_marker])
+ :- Calc(select=[c, true AS vcol_left_marker, null:BOOLEAN AS vcol_right_marker])
+ : +- TableSourceScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+ +- Calc(select=[f, null:BOOLEAN AS vcol_left_marker, true AS vcol_right_marker])
+ +- TableSourceScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
<TestCase name="testIntersectLeftIsEmpty">
<Resource name="sql">
<![CDATA[SELECT c FROM T1 WHERE 1=0 INTERSECT SELECT f FROM T2]]>
@@ -110,6 +138,34 @@ GroupAggregate(groupBy=[c], select=[c])
]]>
</Resource>
</TestCase>
+ <TestCase name="testMinusAll">
+ <Resource name="sql">
+ <![CDATA[SELECT c FROM T1 EXCEPT ALL SELECT f FROM T2]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalMinus(all=[true])
+:- LogicalProject(c=[$2])
+: +- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b, c)]]])
++- LogicalProject(f=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(d, e, f)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[c0 AS c])
++- Correlate(invocation=[org$apache$flink$table$functions$tablefunctions$ReplicateRows$8a783e32f854e18fdaeaa274e3162b0b($0, $1)], correlate=[table(ReplicateRows(sum_vcol_marker,c))], select=[sum_vcol_marker,c,c0], rowType=[RecordType(BIGINT sum_vcol_marker, VARCHAR(2147483647) c, VARCHAR(2147483647) c0)], joinType=[INNER])
+ +- Calc(select=[sum_vcol_marker, c], where=[>(sum_vcol_marker, 0)])
+ +- GroupAggregate(groupBy=[c], select=[c, SUM(vcol_marker) AS sum_vcol_marker])
+ +- Exchange(distribution=[hash[c]])
+ +- Union(all=[true], union=[c, vcol_marker])
+ :- Calc(select=[c, 1:BIGINT AS vcol_marker])
+ : +- TableSourceScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+ +- Calc(select=[f, -1:BIGINT AS vcol_marker])
+ +- TableSourceScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
<TestCase name="testMinusLeftIsEmpty">
<Resource name="sql">
<![CDATA[SELECT c FROM T1 WHERE 1=0 EXCEPT SELECT f FROM T2]]>
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/SetOperatorsTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/SetOperatorsTest.scala
index 635b9b7..f707556 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/SetOperatorsTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/SetOperatorsTest.scala
@@ -52,7 +52,7 @@ class SetOperatorsTest extends TableTestBase {
util.verifyPlan("SELECT a, b, c FROM T1 UNION ALL SELECT d, c, e FROM T3")
}
- @Test(expected = classOf[TableException])
+ @Test
def testIntersectAll(): Unit = {
util.verifyPlan("SELECT c FROM T1 INTERSECT ALL SELECT f FROM T2")
}
@@ -63,7 +63,7 @@ class SetOperatorsTest extends TableTestBase {
util.verifyPlan("SELECT a, b, c FROM T1 INTERSECT SELECT d, c, e FROM T3")
}
- @Test(expected = classOf[TableException])
+ @Test
def testMinusAll(): Unit = {
util.verifyPlan("SELECT c FROM T1 EXCEPT ALL SELECT f FROM T2")
}
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ReplaceIntersectWithSemiJoinRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ReplaceIntersectWithSemiJoinRuleTest.scala
index c867de8..4743404 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ReplaceIntersectWithSemiJoinRuleTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ReplaceIntersectWithSemiJoinRuleTest.scala
@@ -19,7 +19,6 @@
package org.apache.flink.table.plan.rules.logical
import org.apache.flink.api.scala._
-import org.apache.flink.table.api.TableException
import org.apache.flink.table.calcite.CalciteConfig
import org.apache.flink.table.plan.optimize.program.{BatchOptimizeContext, FlinkChainedProgram,
FlinkHepRuleSetProgramBuilder, HEP_RULES_EXECUTION_TYPE}
@@ -57,11 +56,6 @@ class ReplaceIntersectWithSemiJoinRuleTest extends TableTestBase {
}
@Test
- def testIntersectAll(): Unit = {
- util.verifyPlanNotExpected("SELECT c FROM T1 INTERSECT ALL SELECT f FROM T2", "joinType=[semi]")
- }
-
- @Test
def testIntersect(): Unit = {
util.verifyPlan("SELECT c FROM T1 INTERSECT SELECT f FROM T2")
}
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRuleTest.scala
index f5b6fbe..c66739f 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRuleTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRuleTest.scala
@@ -55,11 +55,6 @@ class ReplaceMinusWithAntiJoinRuleTest extends TableTestBase {
}
@Test
- def testExceptAll(): Unit = {
- util.verifyPlanNotExpected("SELECT c FROM T1 EXCEPT ALL SELECT f FROM T2", "joinType=[anti]")
- }
-
- @Test
def testExcept(): Unit = {
util.verifyPlan("SELECT c FROM T1 EXCEPT SELECT f FROM T2")
}
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/RewriteIntersectAllRuleTest.scala
similarity index 72%
copy from flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRuleTest.scala
copy to flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/RewriteIntersectAllRuleTest.scala
index f5b6fbe..30baf4d 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRuleTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/RewriteIntersectAllRuleTest.scala
@@ -28,9 +28,9 @@ import org.apache.calcite.tools.RuleSets
import org.junit.{Before, Test}
/**
- * Test for [[ReplaceMinusWithAntiJoinRule]].
+ * Test for [[RewriteIntersectAllRule]].
*/
-class ReplaceMinusWithAntiJoinRuleTest extends TableTestBase {
+class RewriteIntersectAllRuleTest extends TableTestBase {
private val util = batchTestUtil()
@@ -42,7 +42,7 @@ class ReplaceMinusWithAntiJoinRuleTest extends TableTestBase {
FlinkHepRuleSetProgramBuilder.newBuilder
.setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE)
.setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
- .add(RuleSets.ofList(ReplaceMinusWithAntiJoinRule.INSTANCE))
+ .add(RuleSets.ofList(RewriteIntersectAllRule.INSTANCE))
.build()
)
val calciteConfig = CalciteConfig.createBuilder(util.tableEnv.getConfig.getCalciteConfig)
@@ -55,28 +55,24 @@ class ReplaceMinusWithAntiJoinRuleTest extends TableTestBase {
}
@Test
- def testExceptAll(): Unit = {
- util.verifyPlanNotExpected("SELECT c FROM T1 EXCEPT ALL SELECT f FROM T2", "joinType=[anti]")
+ def testIntersectAll(): Unit = {
+ util.verifyPlan("SELECT c FROM T1 INTERSECT ALL SELECT f FROM T2")
}
@Test
- def testExcept(): Unit = {
- util.verifyPlan("SELECT c FROM T1 EXCEPT SELECT f FROM T2")
+ def testIntersectAllWithFilter(): Unit = {
+ util.verifyPlan(
+ "SELECT c FROM ((SELECT * FROM T1) INTERSECT ALL (SELECT * FROM T2)) WHERE a > 1")
}
@Test
- def testExceptWithFilter(): Unit = {
- util.verifyPlan("SELECT c FROM (SELECT * FROM T1 EXCEPT (SELECT * FROM T2)) WHERE b < 2")
+ def testIntersectAllLeftIsEmpty(): Unit = {
+ util.verifyPlan("SELECT c FROM T1 WHERE 1=0 INTERSECT ALL SELECT f FROM T2")
}
@Test
- def testExceptLeftIsEmpty(): Unit = {
- util.verifyPlan("SELECT c FROM T1 WHERE 1=0 EXCEPT SELECT f FROM T2")
- }
-
- @Test
- def testExceptRightIsEmpty(): Unit = {
- util.verifyPlan("SELECT c FROM T1 EXCEPT SELECT f FROM T2 WHERE 1=0")
+ def testIntersectAllRightIsEmpty(): Unit = {
+ util.verifyPlan("SELECT c FROM T1 INTERSECT ALL SELECT f FROM T2 WHERE 1=0")
}
}
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/RewriteMinusAllRuleTest.scala
similarity index 73%
copy from flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRuleTest.scala
copy to flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/RewriteMinusAllRuleTest.scala
index f5b6fbe..c32a11b 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRuleTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/RewriteMinusAllRuleTest.scala
@@ -28,9 +28,9 @@ import org.apache.calcite.tools.RuleSets
import org.junit.{Before, Test}
/**
- * Test for [[ReplaceMinusWithAntiJoinRule]].
+ * Test for [[RewriteMinusAllRule]].
*/
-class ReplaceMinusWithAntiJoinRuleTest extends TableTestBase {
+class RewriteMinusAllRuleTest extends TableTestBase {
private val util = batchTestUtil()
@@ -42,7 +42,7 @@ class ReplaceMinusWithAntiJoinRuleTest extends TableTestBase {
FlinkHepRuleSetProgramBuilder.newBuilder
.setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE)
.setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
- .add(RuleSets.ofList(ReplaceMinusWithAntiJoinRule.INSTANCE))
+ .add(RuleSets.ofList(RewriteMinusAllRule.INSTANCE))
.build()
)
val calciteConfig = CalciteConfig.createBuilder(util.tableEnv.getConfig.getCalciteConfig)
@@ -56,27 +56,21 @@ class ReplaceMinusWithAntiJoinRuleTest extends TableTestBase {
@Test
def testExceptAll(): Unit = {
- util.verifyPlanNotExpected("SELECT c FROM T1 EXCEPT ALL SELECT f FROM T2", "joinType=[anti]")
+ util.verifyPlan("SELECT c FROM T1 EXCEPT ALL SELECT f FROM T2")
}
@Test
- def testExcept(): Unit = {
- util.verifyPlan("SELECT c FROM T1 EXCEPT SELECT f FROM T2")
+ def testExceptAllWithFilter(): Unit = {
+ util.verifyPlan("SELECT c FROM (SELECT * FROM T1 EXCEPT ALL (SELECT * FROM T2)) WHERE b < 2")
}
@Test
- def testExceptWithFilter(): Unit = {
- util.verifyPlan("SELECT c FROM (SELECT * FROM T1 EXCEPT (SELECT * FROM T2)) WHERE b < 2")
+ def testExceptAllLeftIsEmpty(): Unit = {
+ util.verifyPlan("SELECT c FROM T1 WHERE 1=0 EXCEPT ALL SELECT f FROM T2")
}
@Test
- def testExceptLeftIsEmpty(): Unit = {
- util.verifyPlan("SELECT c FROM T1 WHERE 1=0 EXCEPT SELECT f FROM T2")
+ def testExceptAllRightIsEmpty(): Unit = {
+ util.verifyPlan("SELECT c FROM T1 EXCEPT ALL SELECT f FROM T2 WHERE 1=0")
}
-
- @Test
- def testExceptRightIsEmpty(): Unit = {
- util.verifyPlan("SELECT c FROM T1 EXCEPT SELECT f FROM T2 WHERE 1=0")
- }
-
}
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/SetOperatorsTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/SetOperatorsTest.scala
index 85f4d40..c78745b 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/SetOperatorsTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/SetOperatorsTest.scala
@@ -50,7 +50,7 @@ class SetOperatorsTest extends TableTestBase {
util.verifyPlan("SELECT a, b, c FROM T1 UNION ALL SELECT d, c, e FROM T3")
}
- @Test(expected = classOf[TableException])
+ @Test
def testIntersectAll(): Unit = {
util.verifyPlan("SELECT c FROM T1 INTERSECT ALL SELECT f FROM T2")
}
@@ -61,7 +61,7 @@ class SetOperatorsTest extends TableTestBase {
util.verifyPlan("SELECT a, b, c FROM T1 INTERSECT SELECT d, c, e FROM T3")
}
- @Test(expected = classOf[TableException])
+ @Test
def testMinusAll(): Unit = {
util.verifyPlan("SELECT c FROM T1 EXCEPT ALL SELECT f FROM T2")
}
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/SetOperatorsITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/SetOperatorsITCase.scala
new file mode 100644
index 0000000..1a5e6c8
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/SetOperatorsITCase.scala
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.batch.sql
+
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{INT_TYPE_INFO, LONG_TYPE_INFO, STRING_TYPE_INFO}
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.TableConfigOptions
+import org.apache.flink.table.runtime.batch.sql.join.JoinITCaseHelper
+import org.apache.flink.table.runtime.batch.sql.join.JoinType.{JoinType, _}
+import org.apache.flink.table.runtime.utils.BatchTestBase.row
+import org.apache.flink.table.runtime.utils.TestData._
+import org.apache.flink.table.runtime.utils.{BatchScalaTableEnvUtil, BatchTableEnvUtil, BatchTestBase}
+
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+import org.junit.{Before, Test}
+
+import java.util
+
+import scala.util.Random
+
+@RunWith(classOf[Parameterized])
+class SetOperatorsITCase(joinType: JoinType) extends BatchTestBase {
+
+ @Before
+ def before(): Unit = {
+ tEnv.getConfig.getConf.setInteger(TableConfigOptions.SQL_RESOURCE_DEFAULT_PARALLELISM, 3)
+ registerCollection("AllNullTable3", allNullData3, type3, "a, b, c")
+ registerCollection("SmallTable3", smallData3, type3, "a, b, c")
+ registerCollection("Table3", data3, type3, "a, b, c")
+ registerCollection("Table5", data5, type5, "a, b, c, d, e")
+ JoinITCaseHelper.disableOtherJoinOpForJoin(tEnv, joinType)
+ }
+
+ @Test
+ def testIntersect(): Unit = {
+ val data = List(
+ row(1, 1L, "Hi"),
+ row(2, 2L, "Hello"),
+ row(2, 2L, "Hello"),
+ row(3, 2L, "Hello world!")
+ )
+ val shuffleData = Random.shuffle(data)
+ BatchTableEnvUtil.registerCollection(
+ tEnv,
+ "T",
+ shuffleData,
+ new RowTypeInfo(INT_TYPE_INFO, LONG_TYPE_INFO, STRING_TYPE_INFO),
+ "a, b, c")
+
+ checkResult(
+ "SELECT c FROM SmallTable3 INTERSECT SELECT c FROM T",
+ Seq(row("Hi"), row("Hello")))
+ }
+
+ @Test
+ def testIntersectWithFilter(): Unit = {
+ checkResult(
+ "SELECT c FROM ((SELECT * FROM SmallTable3) INTERSECT (SELECT * FROM Table3)) WHERE a > 1",
+ Seq(row("Hello"), row("Hello world")))
+ }
+
+ @Test
+ def testExcept(): Unit = {
+ val data = List(row(1, 1L, "Hi"))
+ BatchTableEnvUtil.registerCollection(
+ tEnv,
+ "T", data,
+ new RowTypeInfo(INT_TYPE_INFO, LONG_TYPE_INFO, STRING_TYPE_INFO),
+ "a, b, c")
+
+ checkResult(
+ "SELECT c FROM SmallTable3 EXCEPT (SELECT c FROM T)",
+ Seq(row("Hello"), row("Hello world")))
+ }
+
+ @Test
+ def testExceptWithFilter(): Unit = {
+ checkResult(
+ "SELECT c FROM (" +
+ "SELECT * FROM SmallTable3 EXCEPT (SELECT a, b, d FROM Table5))" +
+ "WHERE b < 2",
+ Seq(row("Hi")))
+ }
+
+ @Test
+ def testIntersectWithNulls(): Unit = {
+ checkResult(
+ "SELECT c FROM AllNullTable3 INTERSECT SELECT c FROM AllNullTable3",
+ Seq(row(null)))
+ }
+
+ @Test
+ def testExceptWithNulls(): Unit = {
+ checkResult(
+ "SELECT c FROM AllNullTable3 EXCEPT SELECT c FROM AllNullTable3",
+ Seq())
+ }
+
+ @Test
+ def testIntersectAll(): Unit = {
+ BatchScalaTableEnvUtil.registerCollection(tEnv, "T1", Seq(1, 1, 1, 2, 2), "c")
+ BatchScalaTableEnvUtil.registerCollection(tEnv, "T2", Seq(1, 2, 2, 2, 3), "c")
+ checkResult(
+ "SELECT c FROM T1 INTERSECT ALL SELECT c FROM T2",
+ Seq(row(1), row(2), row(2)))
+ }
+
+ @Test
+ def testMinusAll(): Unit = {
+ BatchScalaTableEnvUtil.registerCollection(tEnv, "T2", Seq((1, 1L, "Hi")), "a, b, c")
+ val t1 = "SELECT * FROM SmallTable3"
+ val t2 = "SELECT * FROM T2"
+ checkResult(
+ s"SELECT c FROM (($t1 UNION ALL $t1 UNION ALL $t1) EXCEPT ALL ($t2 UNION ALL $t2))",
+ Seq(
+ row("Hi"),
+ row("Hello"),
+ row("Hello"),
+ row("Hello"),
+ row("Hello world"),
+ row("Hello world"),
+ row("Hello world")))
+ }
+}
+
+object SetOperatorsITCase {
+ @Parameterized.Parameters(name = "{0}")
+ def parameters(): util.Collection[Array[_]] = {
+ util.Arrays.asList(
+ // TODO: only SortMergeJoin is exercised for now; enable the join types commented out below.
+// Array(BroadcastHashJoin),
+// Array(HashJoin),
+// Array(NestedLoopJoin),
+ Array(SortMergeJoin))
+ }
+}
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/SetOperatorsITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/SetOperatorsITCase.scala
new file mode 100644
index 0000000..7aae826
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/SetOperatorsITCase.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.stream.sql
+
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.StateBackendMode
+import org.apache.flink.table.runtime.utils.{StreamingWithStateTestBase, TestData, TestingRetractSink}
+import org.apache.flink.types.Row
+
+import org.junit.Assert.assertEquals
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+
+import scala.collection.mutable
+
+@RunWith(classOf[Parameterized])
+class SetOperatorsITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode) {
+
+ @Test
+ def testIntersect(): Unit = {
+ val tableA = failingDataSource(TestData.smallTupleData3)
+ .toTable(tEnv, 'a1, 'a2, 'a3)
+ val tableB = failingDataSource(TestData.tupleData3)
+ .toTable(tEnv, 'b1, 'b2, 'b3)
+ tEnv.registerTable("A", tableA)
+ tEnv.registerTable("B", tableB)
+
+ val sqlQuery = "SELECT a1, a2, a3 from A INTERSECT SELECT b1, b2, b3 from B"
+
+ val sink = new TestingRetractSink
+ tEnv.sqlQuery(sqlQuery).toRetractStream[Row].addSink(sink).setParallelism(1)
+ env.execute()
+ val expected = mutable.MutableList(
+ "1,1,Hi",
+ "2,2,Hello",
+ "3,2,Hello world")
+ assertEquals(expected.sorted, sink.getRetractResults.sorted)
+ }
+
+ @Test
+ def testExcept(): Unit = {
+ val data1 = new mutable.MutableList[(Int, Long, String)]
+ data1.+=((1, 1L, "Hi1"))
+ data1.+=((1, 2L, "Hi2"))
+ data1.+=((1, 2L, "Hi2"))
+ data1.+=((1, 5L, "Hi3"))
+ data1.+=((2, 7L, "Hi5"))
+ data1.+=((1, 9L, "Hi6"))
+ data1.+=((1, 8L, "Hi8"))
+ data1.+=((3, 8L, "Hi9"))
+
+ val data2 = new mutable.MutableList[(Int, Long, String)]
+ data2.+=((1, 1L, "Hi1"))
+ data2.+=((2, 2L, "Hi2"))
+ data2.+=((3, 2L, "Hi3"))
+
+ val t1 = failingDataSource(data1).toTable(tEnv, 'a1, 'a2, 'a3)
+ val t2 = failingDataSource(data2).toTable(tEnv, 'b1, 'b2, 'b3)
+ tEnv.registerTable("T1", t1)
+ tEnv.registerTable("T2", t2)
+
+ val sqlQuery = "SELECT a3 from T1 EXCEPT SELECT b3 from T2"
+
+ val sink = new TestingRetractSink
+ tEnv.sqlQuery(sqlQuery).toRetractStream[Row].addSink(sink).setParallelism(1)
+ env.execute()
+ val expected = mutable.MutableList(
+ "Hi5", "Hi6", "Hi8", "Hi9"
+ )
+ assertEquals(expected.sorted, sink.getRetractResults.sorted)
+ }
+
+ @Test
+ def testIntersectAll(): Unit = {
+ val t1 = failingDataSource(Seq(1, 1, 1, 2, 2)).toTable(tEnv, 'c)
+ val t2 = failingDataSource(Seq(1, 2, 2, 2, 3)).toTable(tEnv, 'c)
+ tEnv.registerTable("T1", t1)
+ tEnv.registerTable("T2", t2)
+
+ val sqlQuery = "SELECT c FROM T1 INTERSECT ALL SELECT c FROM T2"
+
+ val sink = new TestingRetractSink
+ tEnv.sqlQuery(sqlQuery).toRetractStream[Row].addSink(sink).setParallelism(1)
+ env.execute()
+ val expected = mutable.MutableList("1", "2", "2")
+ assertEquals(expected.sorted, sink.getRetractResults.sorted)
+ }
+
+ @Test
+ def testMinusAll(): Unit = {
+ val tableA = failingDataSource(TestData.smallTupleData3).toTable(tEnv, 'a, 'b, 'c)
+ tEnv.registerTable("tableA", tableA)
+ val tableB = failingDataSource(Seq((1, 1L, "Hi"), (1, 1L, "Hi"))).toTable(tEnv, 'a, 'b, 'c)
+ tEnv.registerTable("tableB", tableB)
+
+ val t1 = "SELECT * FROM tableA"
+ val t2 = "SELECT * FROM tableB"
+ val sqlQuery =
+ s"SELECT c FROM (($t1 UNION ALL $t1 UNION ALL $t1) EXCEPT ALL $t2)"
+
+ val sink = new TestingRetractSink
+ tEnv.sqlQuery(sqlQuery).toRetractStream[Row].addSink(sink).setParallelism(1)
+ env.execute()
+ val expected = mutable.MutableList(
+ "Hi",
+ "Hello",
+ "Hello",
+ "Hello",
+ "Hello world",
+ "Hello world",
+ "Hello world"
+ )
+ assertEquals(expected.sorted, sink.getRetractResults.sorted)
+ }
+
+}