You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ja...@apache.org on 2019/08/08 13:14:09 UTC

[flink] branch release-1.9 updated: [FLINK-13545] [table-planner-blink] JoinToMultiJoinRule should not match SEMI/ANTI LogicalJoin

This is an automated email from the ASF dual-hosted git repository.

jark pushed a commit to branch release-1.9
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.9 by this push:
     new 47d0fed  [FLINK-13545] [table-planner-blink] JoinToMultiJoinRule should not match SEMI/ANTI LogicalJoin
47d0fed is described below

commit 47d0fed51999ccea847b3fa6b6645a6adad54843
Author: godfreyhe <go...@163.com>
AuthorDate: Fri Aug 2 11:50:28 2019 +0800

    [FLINK-13545] [table-planner-blink] JoinToMultiJoinRule should not match SEMI/ANTI LogicalJoin
    
    This closes #9329
---
 .../rules/logical/FlinkJoinToMultiJoinRule.java    | 594 +++++++++++++++++++++
 .../planner/plan/rules/FlinkBatchRuleSets.scala    |   2 +-
 .../planner/plan/rules/FlinkStreamRuleSets.scala   |   2 +-
 .../rules/logical/FlinkJoinToMultiJoinRuleTest.xml |  81 +++
 .../logical/FlinkJoinToMultiJoinRuleTest.scala     |  72 +++
 5 files changed, 749 insertions(+), 2 deletions(-)

diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRule.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRule.java
new file mode 100644
index 0000000..0d4a954
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRule.java
@@ -0,0 +1,594 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Join;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.core.RelFactories;
+import org.apache.calcite.rel.logical.LogicalJoin;
+import org.apache.calcite.rel.rules.MultiJoin;
+import org.apache.calcite.rel.type.RelDataTypeField;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.rex.RexVisitorImpl;
+import org.apache.calcite.tools.RelBuilderFactory;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.ImmutableIntList;
+import org.apache.calcite.util.Pair;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * This rule is copied from Calcite's {@link org.apache.calcite.rel.rules.JoinToMultiJoinRule}.
+ * This file should be removed when upgrading Calcite to version 1.21 [CALCITE-3225].
+ * Modifications:
+ * - matches() rejects SEMI/ANTI joins (originally noted as lines 142-145).
+ * - canCombine() uses the local containsOuter() helper (originally noted as lines 440-451).
+ */
+
+/**
+ * Planner rule to flatten a tree of
+ * {@link org.apache.calcite.rel.logical.LogicalJoin}s
+ * into a single {@link MultiJoin} with N inputs.
+ *
+ * <p>An input is not flattened if
+ * the input is a null generating input in an outer join, i.e., either input in
+ * a full outer join, the right hand side of a left outer join, or the left hand
+ * side of a right outer join.
+ *
+ * <p>Join conditions are also pulled up from the inputs into the topmost
+ * {@link MultiJoin},
+ * unless the input corresponds to a null generating input in an outer join,
+ *
+ * <p>Outer join information is also stored in the {@link MultiJoin}. A
+ * boolean flag indicates if the join is a full outer join, and in the case of
+ * left and right outer joins, the join type and outer join conditions are
+ * stored in arrays in the {@link MultiJoin}. This outer join information is
+ * associated with the null generating input in the outer join. So, in the case
+ * of a left outer join between A and B, the information is associated with B,
+ * not A.
+ *
+ * <p>Here are examples of the {@link MultiJoin}s constructed after this rule
+ * has been applied on following join trees.
+ *
+ * <ul>
+ * <li>A JOIN B &rarr; MJ(A, B)
+ *
+ * <li>A JOIN B JOIN C &rarr; MJ(A, B, C)
+ *
+ * <li>A LEFT JOIN B &rarr; MJ(A, B), left outer join on input#1
+ *
+ * <li>A RIGHT JOIN B &rarr; MJ(A, B), right outer join on input#0
+ *
+ * <li>A FULL JOIN B &rarr; MJ[full](A, B)
+ *
+ * <li>A LEFT JOIN (B JOIN C) &rarr; MJ(A, MJ(B, C)), left outer join on
+ * input#1 in the outermost MultiJoin
+ *
+ * <li>(A JOIN B) LEFT JOIN C &rarr; MJ(A, B, C), left outer join on input#2
+ *
+ * <li>(A LEFT JOIN B) JOIN C &rarr; MJ(MJ(A, B), C), left outer join on input#1
+ * of the inner MultiJoin        TODO
+ *
+ * <li>A LEFT JOIN (B FULL JOIN C) &rarr; MJ(A, MJ[full](B, C)), left outer join
+ * on input#1 in the outermost MultiJoin
+ *
+ * <li>(A LEFT JOIN B) FULL JOIN (C RIGHT JOIN D) &rarr;
+ *      MJ[full](MJ(A, B), MJ(C, D)), left outer join on input #1 in the first
+ *      inner MultiJoin and right outer join on input#0 in the second inner
+ *      MultiJoin
+ * </ul>
+ *
+ * <p>The constructor is parameterized to allow any sub-class of
+ * {@link org.apache.calcite.rel.core.Join}, not just
+ * {@link org.apache.calcite.rel.logical.LogicalJoin}.</p>
+ *
+ * @see org.apache.calcite.rel.rules.FilterMultiJoinMergeRule
+ * @see org.apache.calcite.rel.rules.ProjectMultiJoinMergeRule
+ */
+public class FlinkJoinToMultiJoinRule extends RelOptRule {
+	/** Default instance: matches {@link LogicalJoin}s and uses {@link RelFactories#LOGICAL_BUILDER}. */
+	public static final FlinkJoinToMultiJoinRule INSTANCE =
+			new FlinkJoinToMultiJoinRule(LogicalJoin.class, RelFactories.LOGICAL_BUILDER);
+
+	//~ Constructors -----------------------------------------------------------
+
+	/**
+	 * Creates a FlinkJoinToMultiJoinRule that uses {@link RelFactories#LOGICAL_BUILDER}.
+	 *
+	 * @param clazz the {@link Join} subclass to match
+	 * @deprecated use {@link #FlinkJoinToMultiJoinRule(Class, RelBuilderFactory)} instead
+	 */
+	@Deprecated // to be removed before 2.0
+	public FlinkJoinToMultiJoinRule(Class<? extends Join> clazz) {
+		this(clazz, RelFactories.LOGICAL_BUILDER);
+	}
+
+	/**
+	 * Creates a FlinkJoinToMultiJoinRule.
+	 *
+	 * @param clazz             the {@link Join} subclass to match; each of its two
+	 *                          inputs may be any {@link RelNode}
+	 * @param relBuilderFactory factory for the rel builder handed to {@link RelOptRule}
+	 */
+	public FlinkJoinToMultiJoinRule(Class<? extends Join> clazz,
+			RelBuilderFactory relBuilderFactory) {
+		super(
+				operand(clazz,
+						operand(RelNode.class, any()),
+						operand(RelNode.class, any())),
+				relBuilderFactory, null);
+	}
+
+	//~ Methods ----------------------------------------------------------------
+
+	/**
+	 * Matches every join type except SEMI and ANTI (the Flink-specific change,
+	 * see CALCITE-3225).
+	 *
+	 * <p>NOTE(review): presumably excluded because a SEMI/ANTI join's row type
+	 * covers only its left input, while {@link #onMatch} combines both inputs into
+	 * the resulting MultiJoin — confirm against CALCITE-3225.
+	 */
+	@Override
+	public boolean matches(RelOptRuleCall call) {
+		final Join origJoin = call.rel(0);
+		return origJoin.getJoinType() != JoinRelType.SEMI && origJoin.getJoinType() != JoinRelType.ANTI;
+	}
+
+	/**
+	 * Flattens the matched join, together with any combinable {@link MultiJoin}
+	 * children, into a single {@link MultiJoin} that keeps the join's row type,
+	 * and registers that MultiJoin as equivalent to the join.
+	 */
+	@Override
+	public void onMatch(RelOptRuleCall call) {
+		final Join origJoin = call.rel(0);
+		final RelNode left = call.rel(1);
+		final RelNode right = call.rel(2);
+
+		// combine the children MultiJoin inputs into an array of inputs
+		// for the new MultiJoin
+		final List<ImmutableBitSet> projFieldsList = new ArrayList<>();
+		final List<int[]> joinFieldRefCountsList = new ArrayList<>();
+		final List<RelNode> newInputs =
+				combineInputs(
+						origJoin,
+						left,
+						right,
+						projFieldsList,
+						joinFieldRefCountsList);
+
+		// combine the outer join information from the left and right
+		// inputs, and include the outer join information from the current
+		// join, if it's a left/right outer join
+		final List<Pair<JoinRelType, RexNode>> joinSpecs = new ArrayList<>();
+		combineOuterJoins(
+				origJoin,
+				newInputs,
+				left,
+				right,
+				joinSpecs);
+
+		// pull up the join filters from the children MultiJoinRels and
+		// combine them with the join filter associated with this LogicalJoin to
+		// form the join filter for the new MultiJoin
+		List<RexNode> newJoinFilters = combineJoinFilters(origJoin, left, right);
+
+		// add on the join field reference counts for the join condition
+		// associated with this LogicalJoin
+		final com.google.common.collect.ImmutableMap<Integer, ImmutableIntList> newJoinFieldRefCountsMap =
+				addOnJoinFieldRefCounts(newInputs,
+						origJoin.getRowType().getFieldCount(),
+						origJoin.getCondition(),
+						joinFieldRefCountsList);
+
+		List<RexNode> newPostJoinFilters =
+				combinePostJoinFilters(origJoin, left, right);
+
+		final RexBuilder rexBuilder = origJoin.getCluster().getRexBuilder();
+		RelNode multiJoin =
+				new MultiJoin(
+						origJoin.getCluster(),
+						newInputs,
+						RexUtil.composeConjunction(rexBuilder, newJoinFilters),
+						origJoin.getRowType(),
+						origJoin.getJoinType() == JoinRelType.FULL,
+						Pair.right(joinSpecs),
+						Pair.left(joinSpecs),
+						projFieldsList,
+						newJoinFieldRefCountsMap,
+						RexUtil.composeConjunction(rexBuilder, newPostJoinFilters, true));
+
+		call.transformTo(multiJoin);
+	}
+
+	/**
+	 * Combines the inputs into a LogicalJoin into an array of inputs.
+	 *
+	 * <p>A child that is a combinable {@link MultiJoin} contributes its own inputs,
+	 * projection fields and reference counts; any other child is added as a single
+	 * input with a {@code null} projection and zeroed reference counts.
+	 *
+	 * @param join                   original join
+	 * @param left                   left input into join
+	 * @param right                  right input into join
+	 * @param projFieldsList         returns a list of the new combined projection
+	 *                               fields
+	 * @param joinFieldRefCountsList returns a list of the new combined join
+	 *                               field reference counts
+	 * @return combined left and right inputs in an array
+	 */
+	private List<RelNode> combineInputs(
+			Join join,
+			RelNode left,
+			RelNode right,
+			List<ImmutableBitSet> projFieldsList,
+			List<int[]> joinFieldRefCountsList) {
+		final List<RelNode> newInputs = new ArrayList<>();
+
+		// leave the null generating sides of an outer join intact; don't
+		// pull up those children inputs into the array we're constructing
+		if (canCombine(left, join.getJoinType().generatesNullsOnLeft())) {
+			final MultiJoin leftMultiJoin = (MultiJoin) left;
+			for (int i = 0; i < left.getInputs().size(); i++) {
+				newInputs.add(leftMultiJoin.getInput(i));
+				projFieldsList.add(leftMultiJoin.getProjFields().get(i));
+				joinFieldRefCountsList.add(
+						leftMultiJoin.getJoinFieldRefCountsMap().get(i).toIntArray());
+			}
+		} else {
+			newInputs.add(left);
+			projFieldsList.add(null);
+			joinFieldRefCountsList.add(
+					new int[left.getRowType().getFieldCount()]);
+		}
+
+		if (canCombine(right, join.getJoinType().generatesNullsOnRight())) {
+			final MultiJoin rightMultiJoin = (MultiJoin) right;
+			for (int i = 0; i < right.getInputs().size(); i++) {
+				newInputs.add(rightMultiJoin.getInput(i));
+				projFieldsList.add(
+						rightMultiJoin.getProjFields().get(i));
+				joinFieldRefCountsList.add(
+						rightMultiJoin.getJoinFieldRefCountsMap().get(i).toIntArray());
+			}
+		} else {
+			newInputs.add(right);
+			projFieldsList.add(null);
+			joinFieldRefCountsList.add(
+					new int[right.getRowType().getFieldCount()]);
+		}
+
+		return newInputs;
+	}
+
+	/**
+	 * Combines the outer join conditions and join types from the left and right
+	 * join inputs. If the join itself is either a left or right outer join,
+	 * then the join condition corresponding to the join is also set in the
+	 * position corresponding to the null-generating input into the join. The
+	 * join type is also set.
+	 *
+	 * @param joinRel        join rel
+	 * @param combinedInputs the combined inputs to the join
+	 * @param left           left child of the joinrel
+	 * @param right          right child of the joinrel
+	 * @param joinSpecs      the list where the join types and conditions will be
+	 *                       copied
+	 */
+	private void combineOuterJoins(
+			Join joinRel,
+			List<RelNode> combinedInputs,
+			RelNode left,
+			RelNode right,
+			List<Pair<JoinRelType, RexNode>> joinSpecs) {
+		JoinRelType joinType = joinRel.getJoinType();
+		boolean leftCombined =
+				canCombine(left, joinType.generatesNullsOnLeft());
+		boolean rightCombined =
+				canCombine(right, joinType.generatesNullsOnRight());
+		switch (joinType) {
+			case LEFT:
+				if (leftCombined) {
+					copyOuterJoinInfo(
+							(MultiJoin) left,
+							joinSpecs,
+							0,
+							null,
+							null);
+				} else {
+					joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null));
+				}
+				// the right (null-generating) input carries the outer join info
+				joinSpecs.add(Pair.of(joinType, joinRel.getCondition()));
+				break;
+			case RIGHT:
+				// the left (null-generating) input carries the outer join info
+				joinSpecs.add(Pair.of(joinType, joinRel.getCondition()));
+				if (rightCombined) {
+					copyOuterJoinInfo(
+							(MultiJoin) right,
+							joinSpecs,
+							left.getRowType().getFieldCount(),
+							right.getRowType().getFieldList(),
+							joinRel.getRowType().getFieldList());
+				} else {
+					joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null));
+				}
+				break;
+			default:
+				if (leftCombined) {
+					copyOuterJoinInfo(
+							(MultiJoin) left,
+							joinSpecs,
+							0,
+							null,
+							null);
+				} else {
+					joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null));
+				}
+				if (rightCombined) {
+					copyOuterJoinInfo(
+							(MultiJoin) right,
+							joinSpecs,
+							left.getRowType().getFieldCount(),
+							right.getRowType().getFieldList(),
+							joinRel.getRowType().getFieldList());
+				} else {
+					joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null));
+				}
+		}
+	}
+
+	/**
+	 * Copies outer join data from a source MultiJoin to a new set of arrays.
+	 * Also adjusts the conditions to reflect the new position of an input if
+	 * that input ends up being shifted to the right.
+	 *
+	 * @param multiJoin        the source MultiJoin
+	 * @param destJoinSpecs    the list where the join types and conditions will
+	 *                         be copied
+	 * @param adjustmentAmount if non-zero, the amount the RexInputRefs in the join
+	 *                         conditions need to be shifted by; if zero, the
+	 *                         conditions are copied unchanged
+	 * @param srcFields        the source fields that the original join conditions
+	 *                         are referencing; required when adjustmentAmount is
+	 *                         non-zero
+	 * @param destFields       the destination fields that the new join conditions
+	 *                         will reference; required when adjustmentAmount is
+	 *                         non-zero
+	 */
+	private void copyOuterJoinInfo(
+			MultiJoin multiJoin,
+			List<Pair<JoinRelType, RexNode>> destJoinSpecs,
+			int adjustmentAmount,
+			List<RelDataTypeField> srcFields,
+			List<RelDataTypeField> destFields) {
+		final List<Pair<JoinRelType, RexNode>> srcJoinSpecs =
+				Pair.zip(
+						multiJoin.getJoinTypes(),
+						multiJoin.getOuterJoinConditions());
+
+		if (adjustmentAmount == 0) {
+			destJoinSpecs.addAll(srcJoinSpecs);
+		} else {
+			assert srcFields != null;
+			assert destFields != null;
+			int nFields = srcFields.size();
+			int[] adjustments = new int[nFields];
+			for (int idx = 0; idx < nFields; idx++) {
+				adjustments[idx] = adjustmentAmount;
+			}
+			for (Pair<JoinRelType, RexNode> src
+					: srcJoinSpecs) {
+				destJoinSpecs.add(
+						Pair.of(
+								src.left,
+								src.right == null
+										? null
+										: src.right.accept(
+										new RelOptUtil.RexInputConverter(
+												multiJoin.getCluster().getRexBuilder(),
+												srcFields, destFields, adjustments))));
+			}
+		}
+	}
+
+	/**
+	 * Collects the join filters from the left and right inputs (if they are
+	 * MultiJoinRels) and the join filter of the joinrel, unless the inputs
+	 * correspond to null generating inputs in an outer join. The caller ANDs the
+	 * returned filters into a single join filter.
+	 *
+	 * @param joinRel join rel
+	 * @param left    left child of the join
+	 * @param right   right child of the join
+	 * @return the join filters to be AND-ed together
+	 */
+	private List<RexNode> combineJoinFilters(
+			Join joinRel,
+			RelNode left,
+			RelNode right) {
+		JoinRelType joinType = joinRel.getJoinType();
+
+		// AND the join condition if this isn't a left or right outer join;
+		// in those cases, the outer join condition is already tracked
+		// separately
+		final List<RexNode> filters = new ArrayList<>();
+		if ((joinType != JoinRelType.LEFT) && (joinType != JoinRelType.RIGHT)) {
+			filters.add(joinRel.getCondition());
+		}
+		if (canCombine(left, joinType.generatesNullsOnLeft())) {
+			filters.add(((MultiJoin) left).getJoinFilter());
+		}
+		// Need to adjust the RexInputs of the right child, since
+		// those need to shift over to the right
+		if (canCombine(right, joinType.generatesNullsOnRight())) {
+			MultiJoin multiJoin = (MultiJoin) right;
+			filters.add(
+					shiftRightFilter(joinRel, left, multiJoin,
+							multiJoin.getJoinFilter()));
+		}
+
+		return filters;
+	}
+
+	/**
+	 * Returns whether an input can be merged into a given relational expression
+	 * without changing semantics.
+	 *
+	 * <p>Only a {@link MultiJoin} that is not a full outer join, contains no outer
+	 * joins, and is not on a null-generating side can be merged.
+	 *
+	 * @param input          input into a join
+	 * @param nullGenerating true if the input is null generating
+	 * @return true if the input can be combined into a parent MultiJoin
+	 */
+	private boolean canCombine(RelNode input, boolean nullGenerating) {
+		return input instanceof MultiJoin
+				&& !((MultiJoin) input).isFullOuterJoin()
+				&& !containsOuter((MultiJoin) input)
+				&& !nullGenerating;
+	}
+
+	/**
+	 * Returns whether any of the join types recorded in the given MultiJoin is an
+	 * outer join.
+	 *
+	 * @param multiJoin the MultiJoin to inspect
+	 * @return true if at least one recorded join type is an outer join
+	 */
+	private boolean containsOuter(MultiJoin multiJoin) {
+		for (JoinRelType joinType : multiJoin.getJoinTypes()) {
+			if (joinType.isOuterJoin()) {
+				return true;
+			}
+		}
+		return false;
+	}
+
+	/**
+	 * Shifts a filter originating from the right child of the LogicalJoin to the
+	 * right, to reflect the filter now being applied on the resulting
+	 * MultiJoin.
+	 *
+	 * @param joinRel     the original LogicalJoin
+	 * @param left        the left child of the LogicalJoin
+	 * @param right       the right child of the LogicalJoin
+	 * @param rightFilter the filter originating from the right child; may be null
+	 * @return the adjusted right filter, or null if {@code rightFilter} is null
+	 */
+	private RexNode shiftRightFilter(
+			Join joinRel,
+			RelNode left,
+			MultiJoin right,
+			RexNode rightFilter) {
+		if (rightFilter == null) {
+			return null;
+		}
+
+		// every field of the right child moves past all fields of the left child
+		int nFieldsOnLeft = left.getRowType().getFieldList().size();
+		int nFieldsOnRight = right.getRowType().getFieldList().size();
+		int[] adjustments = new int[nFieldsOnRight];
+		for (int i = 0; i < nFieldsOnRight; i++) {
+			adjustments[i] = nFieldsOnLeft;
+		}
+		rightFilter =
+				rightFilter.accept(
+						new RelOptUtil.RexInputConverter(
+								joinRel.getCluster().getRexBuilder(),
+								right.getRowType().getFieldList(),
+								joinRel.getRowType().getFieldList(),
+								adjustments));
+		return rightFilter;
+	}
+
+	/**
+	 * Adds on to the existing join condition reference counts the references
+	 * from the new join condition.
+	 *
+	 * @param multiJoinInputs          inputs into the new MultiJoin
+	 * @param nTotalFields             total number of fields in the MultiJoin
+	 * @param joinCondition            the new join condition
+	 * @param origJoinFieldRefCounts   existing join condition reference counts,
+	 *                                 one array per input; not modified
+	 *
+	 * @return immutable map from input ordinal to that input's per-field join
+	 *         reference counts, including the references of {@code joinCondition}
+	 */
+	private com.google.common.collect.ImmutableMap<Integer, ImmutableIntList> addOnJoinFieldRefCounts(
+			List<RelNode> multiJoinInputs,
+			int nTotalFields,
+			RexNode joinCondition,
+			List<int[]> origJoinFieldRefCounts) {
+		// count the input references in the join condition
+		int[] joinCondRefCounts = new int[nTotalFields];
+		joinCondition.accept(new FlinkJoinToMultiJoinRule.InputReferenceCounter(joinCondRefCounts));
+
+		// first, make a copy of the ref counters
+		final Map<Integer, int[]> refCountsMap = new HashMap<>();
+		int nInputs = multiJoinInputs.size();
+		int currInput = 0;
+		for (int[] origRefCounts : origJoinFieldRefCounts) {
+			refCountsMap.put(
+					currInput,
+					origRefCounts.clone());
+			currInput++;
+		}
+
+		// add on to the counts for each input into the MultiJoin the
+		// reference counts computed for the current join condition
+		currInput = -1;
+		int startField = 0;
+		int nFields = 0;
+		for (int i = 0; i < nTotalFields; i++) {
+			if (joinCondRefCounts[i] == 0) {
+				continue;
+			}
+			// advance to the input whose field range contains global field i
+			while (i >= (startField + nFields)) {
+				startField += nFields;
+				currInput++;
+				assert currInput < nInputs;
+				nFields =
+						multiJoinInputs.get(currInput).getRowType().getFieldCount();
+			}
+			int[] refCounts = refCountsMap.get(currInput);
+			refCounts[i - startField] += joinCondRefCounts[i];
+		}
+
+		final com.google.common.collect.ImmutableMap.Builder<Integer, ImmutableIntList> builder =
+				com.google.common.collect.ImmutableMap.builder();
+		for (Map.Entry<Integer, int[]> entry : refCountsMap.entrySet()) {
+			builder.put(entry.getKey(), ImmutableIntList.of(entry.getValue()));
+		}
+		return builder.build();
+	}
+
+	/**
+	 * Collects the post-join filters from the left and right inputs (if they
+	 * are MultiJoinRels). The caller ANDs the returned filters into a single
+	 * filter.
+	 *
+	 * @param joinRel the original LogicalJoin
+	 * @param left    left child of the LogicalJoin
+	 * @param right   right child of the LogicalJoin
+	 * @return the post-join filters to be AND-ed together; the right child's
+	 *         entry is null when it has no post-join filter
+	 */
+	private List<RexNode> combinePostJoinFilters(
+			Join joinRel,
+			RelNode left,
+			RelNode right) {
+		final List<RexNode> filters = new ArrayList<>();
+		if (right instanceof MultiJoin) {
+			final MultiJoin multiRight = (MultiJoin) right;
+			filters.add(
+					shiftRightFilter(joinRel, left, multiRight,
+							multiRight.getPostJoinFilter()));
+		}
+
+		if (left instanceof MultiJoin) {
+			filters.add(((MultiJoin) left).getPostJoinFilter());
+		}
+
+		return filters;
+	}
+
+	//~ Inner Classes ----------------------------------------------------------
+
+	/**
+	 * Visitor that keeps a reference count of the inputs used by an expression.
+	 *
+	 * <p>Static because it needs no state from the enclosing rule.
+	 */
+	private static class InputReferenceCounter extends RexVisitorImpl<Void> {
+		/** Per-field counters, indexed by {@link RexInputRef#getIndex()}; updated in place. */
+		private final int[] refCounts;
+
+		InputReferenceCounter(int[] refCounts) {
+			super(true);
+			this.refCounts = refCounts;
+		}
+
+		@Override
+		public Void visitInputRef(RexInputRef inputRef) {
+			refCounts[inputRef.getIndex()]++;
+			return null;
+		}
+	}
+}
+
+// End FlinkJoinToMultiJoinRule.java
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala
index 0f48950..076242e 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala
@@ -223,7 +223,7 @@ object FlinkBatchRuleSets {
 
   val JOIN_REORDER_PERPARE_RULES: RuleSet = RuleSets.ofList(
     // merge join to MultiJoin
-    JoinToMultiJoinRule.INSTANCE,
+    FlinkJoinToMultiJoinRule.INSTANCE,
     // merge project to MultiJoin
     ProjectMultiJoinMergeRule.INSTANCE,
     // merge filter to MultiJoin
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
index cb371ab..ce6b6e5 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
@@ -210,7 +210,7 @@ object FlinkStreamRuleSets {
     // merge filter to MultiJoin
     FilterMultiJoinMergeRule.INSTANCE,
     // merge join to MultiJoin
-    JoinToMultiJoinRule.INSTANCE
+    FlinkJoinToMultiJoinRule.INSTANCE
   )
 
   val JOIN_REORDER_RULES: RuleSet = RuleSets.ofList(
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRuleTest.xml
new file mode 100644
index 0000000..92a0a62
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRuleTest.xml
@@ -0,0 +1,81 @@
+<?xml version="1.0" ?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements.  See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to you under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License.  You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+<Root>
+  <TestCase name="testDoesNotMatchAntiJoin">
+    <Resource name="sql">
+      <![CDATA[
+SELECT * FROM (SELECT * FROM T1 JOIN T2 ON a = c) t
+WHERE NOT EXISTS (SELECT e FROM T3  WHERE a = e)
+      ]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3])
++- LogicalFilter(condition=[NOT(EXISTS({
+LogicalFilter(condition=[=($cor0.a, $0)])
+  LogicalTableScan(table=[[default_catalog, default_database, T3, source: [TestTableSource(e, f)]]])
+}))], variablesSet=[[$cor0]])
+   +- LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3])
+      +- LogicalJoin(condition=[=($0, $2)], joinType=[inner])
+         :- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b)]]])
+         +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(c, d)]]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3])
++- LogicalJoin(condition=[=($0, $4)], joinType=[anti])
+   :- MultiJoin(joinFilter=[=($0, $2)], isFullOuterJoin=[false], joinTypes=[[INNER, INNER]], outerJoinConditions=[[NULL, NULL]], projFields=[[{0, 1}, {0, 1}]])
+   :  :- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b)]]])
+   :  +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(c, d)]]])
+   +- LogicalProject(e=[$0])
+      +- LogicalFilter(condition=[true])
+         +- LogicalTableScan(table=[[default_catalog, default_database, T3, source: [TestTableSource(e, f)]]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testDoesNotMatchSemiJoin">
+    <Resource name="sql">
+      <![CDATA[SELECT * FROM (SELECT * FROM T1 JOIN T2 ON a = c) t WHERE a IN (SELECT e FROM T3)]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3])
++- LogicalFilter(condition=[IN($0, {
+LogicalProject(e=[$0])
+  LogicalTableScan(table=[[default_catalog, default_database, T3, source: [TestTableSource(e, f)]]])
+})])
+   +- LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3])
+      +- LogicalJoin(condition=[=($0, $2)], joinType=[inner])
+         :- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b)]]])
+         +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(c, d)]]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3])
++- LogicalJoin(condition=[=($0, $4)], joinType=[semi])
+   :- MultiJoin(joinFilter=[=($0, $2)], isFullOuterJoin=[false], joinTypes=[[INNER, INNER]], outerJoinConditions=[[NULL, NULL]], projFields=[[{0, 1}, {0, 1}]])
+   :  :- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b)]]])
+   :  +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(c, d)]]])
+   +- LogicalProject(e=[$0])
+      +- LogicalTableScan(table=[[default_catalog, default_database, T3, source: [TestTableSource(e, f)]]])
+]]>
+    </Resource>
+  </TestCase>
+</Root>
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRuleTest.scala
new file mode 100644
index 0000000..4364704
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRuleTest.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.planner.plan.rules.logical
+
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.planner.plan.optimize.program.{FlinkBatchProgram, FlinkHepRuleSetProgramBuilder, HEP_RULES_EXECUTION_TYPE}
+import org.apache.flink.table.planner.utils.{TableConfigUtils, TableTestBase}
+
+import org.apache.calcite.plan.hep.HepMatchOrder
+import org.apache.calcite.rel.rules.ProjectMultiJoinMergeRule
+import org.apache.calcite.tools.RuleSets
+import org.junit.{Before, Test}
+
+/**
+  * Test for [[FlinkJoinToMultiJoinRule]].
+  */
+class FlinkJoinToMultiJoinRuleTest extends TableTestBase {
+  private val util = batchTestUtil()
+
+  /**
+    * Runs the default rewrite phase, then a bottom-up HEP program with only
+    * [[FlinkJoinToMultiJoinRule]] and [[ProjectMultiJoinMergeRule]], and registers
+    * the test tables T1(a, b), T2(c, d) and T3(e, f).
+    */
+  @Before
+  def setup(): Unit = {
+    util.buildBatchProgram(FlinkBatchProgram.DEFAULT_REWRITE)
+    val batchProgram =
+      TableConfigUtils.getCalciteConfig(util.tableEnv.getConfig).getBatchProgram.get
+    val multiJoinRules = RuleSets.ofList(
+      FlinkJoinToMultiJoinRule.INSTANCE,
+      ProjectMultiJoinMergeRule.INSTANCE)
+    batchProgram.addLast(
+      "rules",
+      FlinkHepRuleSetProgramBuilder.newBuilder
+        .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_COLLECTION)
+        .setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
+        .add(multiJoinRules)
+        .build())
+
+    util.addTableSource[(Int, Long)]("T1", 'a, 'b)
+    util.addTableSource[(Int, Long)]("T2", 'c, 'd)
+    util.addTableSource[(Int, Long)]("T3", 'e, 'f)
+  }
+
+  /** The IN sub-query becomes a SEMI join; only the inner join below it is flattened. */
+  @Test
+  def testDoesNotMatchSemiJoin(): Unit = {
+    util.verifyPlan(
+      "SELECT * FROM (SELECT * FROM T1 JOIN T2 ON a = c) t WHERE a IN (SELECT e FROM T3)")
+  }
+
+  /** The NOT EXISTS sub-query becomes an ANTI join; only the inner join below it is flattened. */
+  @Test
+  def testDoesNotMatchAntiJoin(): Unit = {
+    val sql =
+      """
+        |SELECT * FROM (SELECT * FROM T1 JOIN T2 ON a = c) t
+        |WHERE NOT EXISTS (SELECT e FROM T3  WHERE a = e)
+      """.stripMargin
+    util.verifyPlan(sql)
+  }
+}