You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ja...@apache.org on 2019/08/08 13:05:54 UTC
[flink] branch master updated: [FLINK-13545] [table-planner-blink]
JoinToMultiJoinRule should not match SEMI/ANTI LogicalJoin
This is an automated email from the ASF dual-hosted git repository.
jark pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 1b1e319 [FLINK-13545] [table-planner-blink] JoinToMultiJoinRule should not match SEMI/ANTI LogicalJoin
1b1e319 is described below
commit 1b1e31944a6c62aad7e3a5854ee00af812702cf9
Author: godfreyhe <go...@163.com>
AuthorDate: Fri Aug 2 11:50:28 2019 +0800
[FLINK-13545] [table-planner-blink] JoinToMultiJoinRule should not match SEMI/ANTI LogicalJoin
This closes #9329
---
.../rules/logical/FlinkJoinToMultiJoinRule.java | 594 +++++++++++++++++++++
.../planner/plan/rules/FlinkBatchRuleSets.scala | 2 +-
.../planner/plan/rules/FlinkStreamRuleSets.scala | 2 +-
.../rules/logical/FlinkJoinToMultiJoinRuleTest.xml | 81 +++
.../logical/FlinkJoinToMultiJoinRuleTest.scala | 72 +++
5 files changed, 749 insertions(+), 2 deletions(-)
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRule.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRule.java
new file mode 100644
index 0000000..0d4a954
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRule.java
@@ -0,0 +1,594 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Join;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.core.RelFactories;
+import org.apache.calcite.rel.logical.LogicalJoin;
+import org.apache.calcite.rel.rules.MultiJoin;
+import org.apache.calcite.rel.type.RelDataTypeField;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.rex.RexVisitorImpl;
+import org.apache.calcite.tools.RelBuilderFactory;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.ImmutableIntList;
+import org.apache.calcite.util.Pair;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * This rule is copied from Calcite's {@link org.apache.calcite.rel.rules.JoinToMultiJoinRule}.
+ * This file should be removed when upgrading Calcite to version 1.21 (see [CALCITE-3225]).
+ * Modification:
+ * - Does not match SEMI/ANTI join. lines changed (142-145)
+ * - lines changed (440-451)
+ */
+
+/**
+ * Planner rule to flatten a tree of
+ * {@link org.apache.calcite.rel.logical.LogicalJoin}s
+ * into a single {@link MultiJoin} with N inputs.
+ *
+ * <p>An input is not flattened if
+ * the input is a null generating input in an outer join, i.e., either input in
+ * a full outer join, the right hand side of a left outer join, or the left hand
+ * side of a right outer join.
+ *
+ * <p>Join conditions are also pulled up from the inputs into the topmost
+ * {@link MultiJoin},
+ * unless the input corresponds to a null generating input in an outer join.
+ *
+ * <p>Outer join information is also stored in the {@link MultiJoin}. A
+ * boolean flag indicates if the join is a full outer join, and in the case of
+ * left and right outer joins, the join type and outer join conditions are
+ * stored in arrays in the {@link MultiJoin}. This outer join information is
+ * associated with the null generating input in the outer join. So, in the case
+ * of a left outer join between A and B, the information is associated with B,
+ * not A.
+ *
+ * <p>Here are examples of the {@link MultiJoin}s constructed after this rule
+ * has been applied on following join trees.
+ *
+ * <ul>
+ * <li>A JOIN B → MJ(A, B)
+ *
+ * <li>A JOIN B JOIN C → MJ(A, B, C)
+ *
+ * <li>A LEFT JOIN B → MJ(A, B), left outer join on input#1
+ *
+ * <li>A RIGHT JOIN B → MJ(A, B), right outer join on input#0
+ *
+ * <li>A FULL JOIN B → MJ[full](A, B)
+ *
+ * <li>A LEFT JOIN (B JOIN C) → MJ(A, MJ(B, C)), left outer join on
+ * input#1 in the outermost MultiJoin
+ *
+ * <li>(A JOIN B) LEFT JOIN C → MJ(A, B, C), left outer join on input#2
+ *
+ * <li>(A LEFT JOIN B) JOIN C → MJ(MJ(A, B), C), left outer join on input#1
+ * of the inner MultiJoin TODO
+ *
+ * <li>A LEFT JOIN (B FULL JOIN C) → MJ(A, MJ[full](B, C)), left outer join
+ * on input#1 in the outermost MultiJoin
+ *
+ * <li>(A LEFT JOIN B) FULL JOIN (C RIGHT JOIN D) →
+ * MJ[full](MJ(A, B), MJ(C, D)), left outer join on input #1 in the first
+ * inner MultiJoin and right outer join on input#0 in the second inner
+ * MultiJoin
+ * </ul>
+ *
+ * <p>The constructor is parameterized to allow any sub-class of
+ * {@link org.apache.calcite.rel.core.Join}, not just
+ * {@link org.apache.calcite.rel.logical.LogicalJoin}.</p>
+ *
+ * @see org.apache.calcite.rel.rules.FilterMultiJoinMergeRule
+ * @see org.apache.calcite.rel.rules.ProjectMultiJoinMergeRule
+ */
+public class FlinkJoinToMultiJoinRule extends RelOptRule {
+ /**
+ * Shared instance of this rule, constructed for {@link LogicalJoin} with
+ * {@link RelFactories#LOGICAL_BUILDER}.
+ */
+ public static final FlinkJoinToMultiJoinRule INSTANCE =
+ new FlinkJoinToMultiJoinRule(LogicalJoin.class, RelFactories.LOGICAL_BUILDER);
+
+ //~ Constructors -----------------------------------------------------------
+
+ /**
+ * Creates a FlinkJoinToMultiJoinRule for the given join class, using
+ * {@link RelFactories#LOGICAL_BUILDER} as the builder factory.
+ *
+ * @param clazz join class this rule matches
+ * @deprecated use {@link #FlinkJoinToMultiJoinRule(Class, RelBuilderFactory)} instead
+ */
+ @Deprecated // to be removed before 2.0
+ public FlinkJoinToMultiJoinRule(Class<? extends Join> clazz) {
+ this(clazz, RelFactories.LOGICAL_BUILDER);
+ }
+
+ /**
+ * Creates a FlinkJoinToMultiJoinRule.
+ *
+ * @param clazz join class this rule matches
+ * @param relBuilderFactory builder factory passed on to the {@link RelOptRule} constructor
+ */
+ public FlinkJoinToMultiJoinRule(Class<? extends Join> clazz,
+ RelBuilderFactory relBuilderFactory) {
+ super(
+ // operand tree: a join of the given class plus both of its inputs, whatever they are;
+ // onMatch reads them back as call.rel(0), call.rel(1) and call.rel(2)
+ operand(clazz,
+ operand(RelNode.class, any()),
+ operand(RelNode.class, any())),
+ // NOTE(review): the trailing null is presumably the rule description, left for the
+ // base class to derive -- confirm against RelOptRule
+ relBuilderFactory, null);
+ }
+
+ //~ Methods ----------------------------------------------------------------
+
+ /**
+  * Restricts this rule to joins that are neither SEMI nor ANTI.
+  *
+  * @param call rule call whose first matched node is the join
+  * @return {@code false} if the join type is SEMI or ANTI, {@code true} otherwise
+  */
+ @Override
+ public boolean matches(RelOptRuleCall call) {
+     final Join join = call.rel(0);
+     final JoinRelType type = join.getJoinType();
+     return !(type == JoinRelType.SEMI || type == JoinRelType.ANTI);
+ }
+
+ /**
+  * Replaces the matched join with a single {@link MultiJoin} built from the
+  * join and its two inputs.
+  *
+  * <p>The inputs, outer-join specs, join filters, join-field reference counts
+  * and post-join filters are each gathered by a dedicated helper and then
+  * handed to the {@link MultiJoin} constructor.
+  *
+  * @param call rule call holding the join (rel 0), its left input (rel 1)
+  *             and its right input (rel 2)
+  */
+ @Override
+ public void onMatch(RelOptRuleCall call) {
+     final Join origJoin = call.rel(0);
+     final RelNode left = call.rel(1);
+     final RelNode right = call.rel(2);
+
+     // combine the children MultiJoin inputs into an array of inputs
+     // for the new MultiJoin; the two lists are filled in as a side effect
+     final List<ImmutableBitSet> projFieldsList = new ArrayList<>();
+     final List<int[]> joinFieldRefCountsList = new ArrayList<>();
+     final List<RelNode> newInputs =
+         combineInputs(
+             origJoin,
+             left,
+             right,
+             projFieldsList,
+             joinFieldRefCountsList);
+
+     // combine the outer join information from the left and right
+     // inputs, and include the outer join information from the current
+     // join, if it's a left/right outer join
+     final List<Pair<JoinRelType, RexNode>> joinSpecs = new ArrayList<>();
+     combineOuterJoins(
+         origJoin,
+         newInputs,
+         left,
+         right,
+         joinSpecs);
+
+     // pull up the join filters from the children MultiJoins and
+     // combine them with the join filter associated with this join to
+     // form the join filter for the new MultiJoin
+     final List<RexNode> newJoinFilters = combineJoinFilters(origJoin, left, right);
+
+     // add on the join field reference counts for the join condition
+     // associated with this join; must run after combineInputs has
+     // populated joinFieldRefCountsList
+     final com.google.common.collect.ImmutableMap<Integer, ImmutableIntList> newJoinFieldRefCountsMap =
+         addOnJoinFieldRefCounts(newInputs,
+             origJoin.getRowType().getFieldCount(),
+             origJoin.getCondition(),
+             joinFieldRefCountsList);
+
+     final List<RexNode> newPostJoinFilters =
+         combinePostJoinFilters(origJoin, left, right);
+
+     final RexBuilder rexBuilder = origJoin.getCluster().getRexBuilder();
+     final RelNode multiJoin =
+         new MultiJoin(
+             origJoin.getCluster(),
+             newInputs,
+             RexUtil.composeConjunction(rexBuilder, newJoinFilters),
+             origJoin.getRowType(),
+             origJoin.getJoinType() == JoinRelType.FULL,
+             // joinSpecs pairs are (join type, condition): conditions first, then types
+             Pair.right(joinSpecs),
+             Pair.left(joinSpecs),
+             projFieldsList,
+             newJoinFieldRefCountsMap,
+             RexUtil.composeConjunction(rexBuilder, newPostJoinFilters, true));
+
+     call.transformTo(multiJoin);
+ }
+
+ /**
+  * Builds the input list of the new MultiJoin from the two inputs of a join.
+  *
+  * <p>An input that {@link #canCombine} accepts is replaced by its own
+  * children; any other input is kept as a single entry. The left input is
+  * processed before the right one.
+  *
+  * @param join original join
+  * @param left left input into join
+  * @param right right input into join
+  * @param projFieldsList receives one projection-field entry per resulting
+  *                       input ({@code null} for an input kept as is)
+  * @param joinFieldRefCountsList receives one reference-count array per
+  *                               resulting input
+  * @return combined left and right inputs
+  */
+ private List<RelNode> combineInputs(
+         Join join,
+         RelNode left,
+         RelNode right,
+         List<ImmutableBitSet> projFieldsList,
+         List<int[]> joinFieldRefCountsList) {
+     final List<RelNode> newInputs = new ArrayList<>();
+     final JoinRelType joinType = join.getJoinType();
+
+     // leave the null generating sides of an outer join intact; don't
+     // pull up those children inputs into the list we're constructing
+     appendInput(left, joinType.generatesNullsOnLeft(),
+         newInputs, projFieldsList, joinFieldRefCountsList);
+     appendInput(right, joinType.generatesNullsOnRight(),
+         newInputs, projFieldsList, joinFieldRefCountsList);
+
+     return newInputs;
+ }
+
+ /**
+  * Appends one join input to the three parallel lists, either flattened into
+  * its MultiJoin children or as a single entry.
+  */
+ private void appendInput(
+         RelNode input,
+         boolean nullGenerating,
+         List<RelNode> inputs,
+         List<ImmutableBitSet> projFieldsList,
+         List<int[]> joinFieldRefCountsList) {
+     if (!canCombine(input, nullGenerating)) {
+         inputs.add(input);
+         projFieldsList.add(null);
+         // fresh, all-zero counters sized to the input's row type
+         joinFieldRefCountsList.add(new int[input.getRowType().getFieldCount()]);
+         return;
+     }
+
+     final MultiJoin multiJoin = (MultiJoin) input;
+     final int childCount = input.getInputs().size();
+     for (int child = 0; child < childCount; child++) {
+         inputs.add(multiJoin.getInput(child));
+         projFieldsList.add(multiJoin.getProjFields().get(child));
+         joinFieldRefCountsList.add(
+             multiJoin.getJoinFieldRefCountsMap().get(child).toIntArray());
+     }
+ }
+
+ /**
+ * Combines the outer join conditions and join types from the left and right
+ * join inputs. If the join itself is either a left or right outer join,
+ * then the join condition corresponding to the join is also set in the
+ * position corresponding to the null-generating input into the join. The
+ * join type is also set.
+ *
+ * @param joinRel join rel
+ * @param combinedInputs the combined inputs to the join
+ * @param left left child of the joinrel
+ * @param right right child of the joinrel
+ * @param joinSpecs the list where the join types and conditions will be
+ * copied
+ */
+ private void combineOuterJoins(
+ Join joinRel,
+ List<RelNode> combinedInputs,
+ RelNode left,
+ RelNode right,
+ List<Pair<JoinRelType, RexNode>> joinSpecs) {
+ JoinRelType joinType = joinRel.getJoinType();
+ boolean leftCombined =
+ canCombine(left, joinType.generatesNullsOnLeft());
+ boolean rightCombined =
+ canCombine(right, joinType.generatesNullsOnRight());
+ switch (joinType) {
+ case LEFT:
+ if (leftCombined) {
+ copyOuterJoinInfo(
+ (MultiJoin) left,
+ joinSpecs,
+ 0,
+ null,
+ null);
+ } else {
+ joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null));
+ }
+ joinSpecs.add(Pair.of(joinType, joinRel.getCondition()));
+ break;
+ case RIGHT:
+ joinSpecs.add(Pair.of(joinType, joinRel.getCondition()));
+ if (rightCombined) {
+ copyOuterJoinInfo(
+ (MultiJoin) right,
+ joinSpecs,
+ left.getRowType().getFieldCount(),
+ right.getRowType().getFieldList(),
+ joinRel.getRowType().getFieldList());
+ } else {
+ joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null));
+ }
+ break;
+ default:
+ if (leftCombined) {
+ copyOuterJoinInfo(
+ (MultiJoin) left,
+ joinSpecs,
+ 0,
+ null,
+ null);
+ } else {
+ joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null));
+ }
+ if (rightCombined) {
+ copyOuterJoinInfo(
+ (MultiJoin) right,
+ joinSpecs,
+ left.getRowType().getFieldCount(),
+ right.getRowType().getFieldList(),
+ joinRel.getRowType().getFieldList());
+ } else {
+ joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null));
+ }
+ }
+ }
+
+ /**
+  * Copies the join types and outer join conditions of a source MultiJoin
+  * into a destination list, shifting the input references of the conditions
+  * when the source inputs end up further to the right.
+  *
+  * @param multiJoin the source MultiJoin
+  * @param destJoinSpecs the list where the join types and conditions will
+  *                      be copied
+  * @param adjustmentAmount zero to copy the specs unchanged; otherwise the
+  *                         amount the RexInputRefs in the join conditions
+  *                         need to be adjusted by
+  * @param srcFields the source fields that the original join conditions
+  *                  are referencing; must not be null when adjusting
+  * @param destFields the destination fields that the new join conditions
+  *                   will be referencing; must not be null when adjusting
+  */
+ private void copyOuterJoinInfo(
+         MultiJoin multiJoin,
+         List<Pair<JoinRelType, RexNode>> destJoinSpecs,
+         int adjustmentAmount,
+         List<RelDataTypeField> srcFields,
+         List<RelDataTypeField> destFields) {
+     final List<Pair<JoinRelType, RexNode>> srcJoinSpecs =
+         Pair.zip(
+             multiJoin.getJoinTypes(),
+             multiJoin.getOuterJoinConditions());
+
+     if (adjustmentAmount == 0) {
+         destJoinSpecs.addAll(srcJoinSpecs);
+         return;
+     }
+
+     assert srcFields != null;
+     assert destFields != null;
+
+     // every source field moves by the same amount
+     final int[] adjustments = new int[srcFields.size()];
+     java.util.Arrays.fill(adjustments, adjustmentAmount);
+
+     for (Pair<JoinRelType, RexNode> spec : srcJoinSpecs) {
+         RexNode shifted = null;
+         if (spec.right != null) {
+             shifted =
+                 spec.right.accept(
+                     new RelOptUtil.RexInputConverter(
+                         multiJoin.getCluster().getRexBuilder(),
+                         srcFields, destFields, adjustments));
+         }
+         destJoinSpecs.add(Pair.of(spec.left, shifted));
+     }
+ }
+
+ /**
+  * Collects the join filters that go into the new MultiJoin: the join's own
+  * condition (unless it is a LEFT or RIGHT join) followed by the join
+  * filters of whichever inputs can be combined.
+  *
+  * @param joinRel join rel
+  * @param left left child of the join
+  * @param right right child of the join
+  * @return the filters to be AND-ed together, in the order join, left, right
+  */
+ private List<RexNode> combineJoinFilters(
+         Join joinRel,
+         RelNode left,
+         RelNode right) {
+     final JoinRelType joinType = joinRel.getJoinType();
+     final List<RexNode> filters = new ArrayList<>();
+
+     // the condition of a left or right outer join is already tracked
+     // separately as an outer join condition, so only AND it in otherwise
+     final boolean trackedSeparately =
+         joinType == JoinRelType.LEFT || joinType == JoinRelType.RIGHT;
+     if (!trackedSeparately) {
+         filters.add(joinRel.getCondition());
+     }
+
+     if (canCombine(left, joinType.generatesNullsOnLeft())) {
+         final MultiJoin leftMultiJoin = (MultiJoin) left;
+         filters.add(leftMultiJoin.getJoinFilter());
+     }
+
+     if (canCombine(right, joinType.generatesNullsOnRight())) {
+         // the RexInputRefs of the right child need to shift over to the right
+         final MultiJoin rightMultiJoin = (MultiJoin) right;
+         final RexNode shifted =
+             shiftRightFilter(joinRel, left, rightMultiJoin,
+                 rightMultiJoin.getJoinFilter());
+         filters.add(shifted);
+     }
+
+     return filters;
+ }
+
+ /**
+ * Returns whether an input can be merged into a given relational expression
+ * without changing semantics.
+ *
+ * @param input input into a join
+ * @param nullGenerating true if the input is null generating
+ * @return true if the input can be combined into a parent MultiJoin
+ */
+ private boolean canCombine(RelNode input, boolean nullGenerating) {
+ return input instanceof MultiJoin
+ && !((MultiJoin) input).isFullOuterJoin()
+ && !(containsOuter((MultiJoin) input))
+ && !nullGenerating;
+ }
+
+ private boolean containsOuter(MultiJoin multiJoin) {
+ for (JoinRelType joinType : multiJoin.getJoinTypes()) {
+ if (joinType.isOuterJoin()) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ /**
+ * Shifts a filter originating from the right child of the LogicalJoin to the
+ * right, to reflect the filter now being applied on the resulting
+ * MultiJoin.
+ *
+ * @param joinRel the original LogicalJoin
+ * @param left the left child of the LogicalJoin
+ * @param right the right child of the LogicalJoin
+ * @param rightFilter the filter originating from the right child
+ * @return the adjusted right filter
+ */
+ private RexNode shiftRightFilter(
+ Join joinRel,
+ RelNode left,
+ MultiJoin right,
+ RexNode rightFilter) {
+ if (rightFilter == null) {
+ return null;
+ }
+
+ int nFieldsOnLeft = left.getRowType().getFieldList().size();
+ int nFieldsOnRight = right.getRowType().getFieldList().size();
+ int[] adjustments = new int[nFieldsOnRight];
+ for (int i = 0; i < nFieldsOnRight; i++) {
+ adjustments[i] = nFieldsOnLeft;
+ }
+ rightFilter =
+ rightFilter.accept(
+ new RelOptUtil.RexInputConverter(
+ joinRel.getCluster().getRexBuilder(),
+ right.getRowType().getFieldList(),
+ joinRel.getRowType().getFieldList(),
+ adjustments));
+ return rightFilter;
+ }
+
+ /**
+ * Adds on to the existing join condition reference counts the references
+ * from the new join condition.
+ *
+ * <p>The existing counts are copied first, so the arrays passed in are not
+ * modified.
+ *
+ * @param multiJoinInputs inputs into the new MultiJoin
+ * @param nTotalFields total number of fields in the MultiJoin
+ * @param joinCondition the new join condition
+ * @param origJoinFieldRefCounts existing join condition reference counts,
+ * one array per input, in input order
+ *
+ * @return map from input ordinal to the updated reference counts of that
+ * input's fields
+ */
+ private com.google.common.collect.ImmutableMap<Integer, ImmutableIntList> addOnJoinFieldRefCounts(
+ List<RelNode> multiJoinInputs,
+ int nTotalFields,
+ RexNode joinCondition,
+ List<int[]> origJoinFieldRefCounts) {
+ // count the input references in the join condition
+ int[] joinCondRefCounts = new int[nTotalFields];
+ joinCondition.accept(new FlinkJoinToMultiJoinRule.InputReferenceCounter(joinCondRefCounts));
+
+ // first, make a copy of the ref counters
+ final Map<Integer, int[]> refCountsMap = new HashMap<>();
+ int nInputs = multiJoinInputs.size();
+ int currInput = 0;
+ for (int[] origRefCounts : origJoinFieldRefCounts) {
+ refCountsMap.put(
+ currInput,
+ origRefCounts.clone());
+ currInput++;
+ }
+
+ // add on to the counts for each input into the MultiJoin the
+ // reference counts computed for the current join condition
+ // currInput starts before the first input; startField/nFields describe the
+ // range of global field indexes covered by the current input
+ currInput = -1;
+ int startField = 0;
+ int nFields = 0;
+ for (int i = 0; i < nTotalFields; i++) {
+ if (joinCondRefCounts[i] == 0) {
+ continue;
+ }
+ // advance to the input whose field range contains global field i
+ while (i >= (startField + nFields)) {
+ startField += nFields;
+ currInput++;
+ assert currInput < nInputs;
+ nFields =
+ multiJoinInputs.get(currInput).getRowType().getFieldCount();
+ }
+ int[] refCounts = refCountsMap.get(currInput);
+ // i - startField is the field's index local to the current input
+ refCounts[i - startField] += joinCondRefCounts[i];
+ }
+
+ // freeze the per-input counters into the immutable map that is returned
+ final com.google.common.collect.ImmutableMap.Builder<Integer, ImmutableIntList> builder =
+ com.google.common.collect.ImmutableMap.builder();
+ for (Map.Entry<Integer, int[]> entry : refCountsMap.entrySet()) {
+ builder.put(entry.getKey(), ImmutableIntList.of(entry.getValue()));
+ }
+ return builder.build();
+ }
+
+ /**
+  * Collects the post-join filters of the inputs that are MultiJoins: first
+  * the right input's filter, shifted to the join's row type, then the left
+  * input's filter.
+  *
+  * @param joinRel the original join
+  * @param left left child of the join
+  * @param right right child of the join
+  * @return the post-join filters to be AND-ed together; entries may be null
+  */
+ private List<RexNode> combinePostJoinFilters(
+         Join joinRel,
+         RelNode left,
+         RelNode right) {
+     final List<RexNode> postFilters = new ArrayList<>();
+
+     if (right instanceof MultiJoin) {
+         final MultiJoin rightMultiJoin = (MultiJoin) right;
+         final RexNode shifted =
+             shiftRightFilter(joinRel, left, rightMultiJoin,
+                 rightMultiJoin.getPostJoinFilter());
+         postFilters.add(shifted);
+     }
+
+     if (left instanceof MultiJoin) {
+         final MultiJoin leftMultiJoin = (MultiJoin) left;
+         postFilters.add(leftMultiJoin.getPostJoinFilter());
+     }
+
+     return postFilters;
+ }
+
+ //~ Inner Classes ----------------------------------------------------------
+
+ /**
+  * Visitor that keeps a reference count of the inputs used by an expression.
+  *
+  * <p>Each visited {@link RexInputRef} increments the slot of the supplied
+  * array at the reference's index; the caller owns the array and reads the
+  * counts from it afterwards. The class is static because it uses no state
+  * of the enclosing rule.
+  */
+ private static class InputReferenceCounter extends RexVisitorImpl<Void> {
+     private final int[] refCounts;
+
+     /**
+      * @param refCounts array to accumulate into; must have a slot for every
+      *                  input index the visited expression can reference
+      */
+     InputReferenceCounter(int[] refCounts) {
+         super(true);
+         this.refCounts = refCounts;
+     }
+
+     @Override
+     public Void visitInputRef(RexInputRef inputRef) {
+         refCounts[inputRef.getIndex()]++;
+         return null;
+     }
+ }
+}
+
+// End FlinkJoinToMultiJoinRule.java
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala
index 0f48950..076242e 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala
@@ -223,7 +223,7 @@ object FlinkBatchRuleSets {
val JOIN_REORDER_PERPARE_RULES: RuleSet = RuleSets.ofList(
// merge join to MultiJoin
- JoinToMultiJoinRule.INSTANCE,
+ FlinkJoinToMultiJoinRule.INSTANCE,
// merge project to MultiJoin
ProjectMultiJoinMergeRule.INSTANCE,
// merge filter to MultiJoin
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
index beb00e5..25499b6 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
@@ -210,7 +210,7 @@ object FlinkStreamRuleSets {
// merge filter to MultiJoin
FilterMultiJoinMergeRule.INSTANCE,
// merge join to MultiJoin
- JoinToMultiJoinRule.INSTANCE
+ FlinkJoinToMultiJoinRule.INSTANCE
)
val JOIN_REORDER_RULES: RuleSet = RuleSets.ofList(
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRuleTest.xml
new file mode 100644
index 0000000..92a0a62
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRuleTest.xml
@@ -0,0 +1,81 @@
+<?xml version="1.0" ?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements. See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to you under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+<Root>
+ <TestCase name="testDoesNotMatchAntiJoin">
+ <Resource name="sql">
+ <![CDATA[
+SELECT * FROM (SELECT * FROM T1 JOIN T2 ON a = c) t
+WHERE NOT EXISTS (SELECT e FROM T3 WHERE a = e)
+ ]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3])
++- LogicalFilter(condition=[NOT(EXISTS({
+LogicalFilter(condition=[=($cor0.a, $0)])
+ LogicalTableScan(table=[[default_catalog, default_database, T3, source: [TestTableSource(e, f)]]])
+}))], variablesSet=[[$cor0]])
+ +- LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3])
+ +- LogicalJoin(condition=[=($0, $2)], joinType=[inner])
+ :- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b)]]])
+ +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(c, d)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3])
++- LogicalJoin(condition=[=($0, $4)], joinType=[anti])
+ :- MultiJoin(joinFilter=[=($0, $2)], isFullOuterJoin=[false], joinTypes=[[INNER, INNER]], outerJoinConditions=[[NULL, NULL]], projFields=[[{0, 1}, {0, 1}]])
+ : :- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b)]]])
+ : +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(c, d)]]])
+ +- LogicalProject(e=[$0])
+ +- LogicalFilter(condition=[true])
+ +- LogicalTableScan(table=[[default_catalog, default_database, T3, source: [TestTableSource(e, f)]]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testDoesNotMatchSemiJoin">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM (SELECT * FROM T1 JOIN T2 ON a = c) t WHERE a IN (SELECT e FROM T3)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3])
++- LogicalFilter(condition=[IN($0, {
+LogicalProject(e=[$0])
+ LogicalTableScan(table=[[default_catalog, default_database, T3, source: [TestTableSource(e, f)]]])
+})])
+ +- LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3])
+ +- LogicalJoin(condition=[=($0, $2)], joinType=[inner])
+ :- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b)]]])
+ +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(c, d)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3])
++- LogicalJoin(condition=[=($0, $4)], joinType=[semi])
+ :- MultiJoin(joinFilter=[=($0, $2)], isFullOuterJoin=[false], joinTypes=[[INNER, INNER]], outerJoinConditions=[[NULL, NULL]], projFields=[[{0, 1}, {0, 1}]])
+ : :- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b)]]])
+ : +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(c, d)]]])
+ +- LogicalProject(e=[$0])
+ +- LogicalTableScan(table=[[default_catalog, default_database, T3, source: [TestTableSource(e, f)]]])
+]]>
+ </Resource>
+ </TestCase>
+</Root>
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRuleTest.scala
new file mode 100644
index 0000000..4364704
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRuleTest.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.planner.plan.rules.logical
+
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.planner.plan.optimize.program.{FlinkBatchProgram, FlinkHepRuleSetProgramBuilder, HEP_RULES_EXECUTION_TYPE}
+import org.apache.flink.table.planner.utils.{TableConfigUtils, TableTestBase}
+
+import org.apache.calcite.plan.hep.HepMatchOrder
+import org.apache.calcite.rel.rules.ProjectMultiJoinMergeRule
+import org.apache.calcite.tools.RuleSets
+import org.junit.{Before, Test}
+
+/**
+ * Test for [[FlinkJoinToMultiJoinRule]].
+ *
+ * Both cases join T1 and T2 in a subquery and then filter the result with a
+ * subquery on T3; the resulting plans are checked with `verifyPlan`.
+ */
+class FlinkJoinToMultiJoinRuleTest extends TableTestBase {
+ private val util = batchTestUtil()
+
+ @Before
+ def setup(): Unit = {
+ // start from the batch program built up to DEFAULT_REWRITE ...
+ util.buildBatchProgram(FlinkBatchProgram.DEFAULT_REWRITE)
+ val calciteConfig = TableConfigUtils.getCalciteConfig(util.tableEnv.getConfig)
+ // ... and append a HEP program that applies the rule under test together with
+ // ProjectMultiJoinMergeRule, as a rule collection, matching bottom-up
+ calciteConfig.getBatchProgram.get.addLast(
+ "rules",
+ FlinkHepRuleSetProgramBuilder.newBuilder
+ .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_COLLECTION)
+ .setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
+ .add(RuleSets.ofList(
+ FlinkJoinToMultiJoinRule.INSTANCE,
+ ProjectMultiJoinMergeRule.INSTANCE))
+ .build()
+ )
+
+ // three (Int, Long) sources with distinct column names: T1(a, b), T2(c, d), T3(e, f)
+ util.addTableSource[(Int, Long)]("T1", 'a, 'b)
+ util.addTableSource[(Int, Long)]("T2", 'c, 'd)
+ util.addTableSource[(Int, Long)]("T3", 'e, 'f)
+ }
+
+ // IN-subquery over an inner join of T1 and T2
+ @Test
+ def testDoesNotMatchSemiJoin(): Unit = {
+ val sqlQuery =
+ "SELECT * FROM (SELECT * FROM T1 JOIN T2 ON a = c) t WHERE a IN (SELECT e FROM T3)"
+ util.verifyPlan(sqlQuery)
+ }
+
+ // correlated NOT EXISTS subquery over an inner join of T1 and T2
+ @Test
+ def testDoesNotMatchAntiJoin(): Unit = {
+ val sqlQuery =
+ """
+ |SELECT * FROM (SELECT * FROM T1 JOIN T2 ON a = c) t
+ |WHERE NOT EXISTS (SELECT e FROM T3 WHERE a = e)
+ """.stripMargin
+ util.verifyPlan(sqlQuery)
+ }
+}