You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ku...@apache.org on 2019/05/11 12:32:53 UTC
[flink] branch master updated: [FLINK-12371][table-planner-blink]
Add support for converting (NOT) IN / (NOT) EXISTS to semi / anti join
(#8317)
This is an automated email from the ASF dual-hosted git repository.
kurt pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 76ae39a [FLINK-12371][table-planner-blink] Add support for converting (NOT) IN / (NOT) EXISTS to semi / anti join (#8317)
76ae39a is described below
commit 76ae39a18ecb29c49a0ce9a205517b2104daa616
Author: godfrey he <go...@163.com>
AuthorDate: Sat May 11 20:32:33 2019 +0800
[FLINK-12371][table-planner-blink] Add support for converting (NOT) IN / (NOT) EXISTS to semi / anti join (#8317)
This closes #8317
---
.../java/org/apache/calcite/rel/core/Join.java | 341 ++
.../org/apache/calcite/rel/core/JoinRelType.java | 184 +
.../apache/calcite/sql2rel/RelDecorrelator.java | 2834 ++++++++++++
.../flink/table/api/PlannerConfigOptions.java | 6 +
.../plan/rules/logical/FlinkFilterJoinRule.java | 363 ++
.../logical/FlinkJoinPushExpressionsRule.java | 82 +
.../logical/FlinkProjectJoinTransposeRule.java | 151 +
.../plan/rules/logical/SubQueryDecorrelator.java | 1445 +++++++
.../apache/flink/table/plan/util/JoinTypeUtil.java | 38 +-
.../table/calcite/RelTimeIndicatorConverter.scala | 2 -
.../plan/metadata/FlinkRelMdColumnUniqueness.scala | 35 +-
.../plan/metadata/FlinkRelMdDistinctRowCount.scala | 32 +-
.../metadata/FlinkRelMdModifiedMonotonicity.scala | 8 +-
.../FlinkRelMdPercentageOriginalRows.scala | 21 +-
.../plan/metadata/FlinkRelMdPopulationSize.scala | 14 +-
.../table/plan/metadata/FlinkRelMdRowCount.scala | 16 +-
.../plan/metadata/FlinkRelMdSelectivity.scala | 27 +-
.../flink/table/plan/metadata/FlinkRelMdSize.scala | 19 +-
.../table/plan/metadata/FlinkRelMdUniqueKeys.scala | 20 +-
.../plan/nodes/common/CommonPhysicalJoin.scala | 45 +-
.../plan/nodes/logical/FlinkLogicalJoin.scala | 19 +-
.../nodes/physical/batch/BatchExecHashJoin.scala | 74 +-
.../nodes/physical/batch/BatchExecJoinBase.scala | 15 +-
.../physical/batch/BatchExecNestedLoopJoin.scala | 65 +-
.../physical/batch/BatchExecSortMergeJoin.scala | 68 +-
...reamExecJoinBase.scala => StreamExecJoin.scala} | 43 +-
.../physical/stream/StreamExecWindowJoin.scala | 3 +-
.../plan/optimize/program/FlinkBatchProgram.scala | 21 +-
.../plan/optimize/program/FlinkStreamProgram.scala | 11 +
.../plan/reuse/DeadlockBreakupProcessor.scala | 10 +-
.../table/plan/rules/FlinkBatchRuleSets.scala | 15 +-
.../table/plan/rules/FlinkStreamRuleSets.scala | 13 +-
.../rules/logical/FlinkSubQueryRemoveRule.scala | 459 ++
.../logical/SimplifyFilterConditionRule.scala | 103 +
.../physical/batch/BatchExecHashJoinRule.scala | 48 +-
.../physical/batch/BatchExecJoinRuleBase.scala | 25 +-
.../batch/BatchExecNestedLoopJoinRule.scala | 27 +-
.../batch/BatchExecSingleRowJoinRule.scala | 21 +-
.../batch/BatchExecSortMergeJoinRule.scala | 10 +-
.../rules/physical/stream/StreamExecJoinRule.scala | 8 +-
.../physical/stream/StreamExecWindowJoinRule.scala | 12 +-
.../flink/table/plan/util/FlinkRelMdUtil.scala | 30 +-
.../flink/table/plan/util/FlinkRelOptUtil.scala | 401 +-
.../table/plan/util/UpdatingPlanChecker.scala | 4 +-
.../table/plan/batch/sql/DeadlockBreakupTest.xml | 42 +-
.../table/plan/batch/sql/SubplanReuseTest.xml | 34 +-
.../sql/join/BroadcastHashSemiAntiJoinTest.xml | 1956 +++++++++
.../batch/sql/join/NestedLoopSemiAntiJoinTest.xml | 2612 +++++++++++
.../table/plan/batch/sql/join/SemiAntiJoinTest.xml | 2687 ++++++++++++
.../plan/batch/sql/join/ShuffledHashJoinTest.xml | 12 +-
.../sql/join/ShuffledHashSemiAntiJoinTest.xml | 2018 +++++++++
.../batch/sql/join/SortMergeSemiAntiJoinTest.xml | 2103 +++++++++
.../logical/FlinkJoinPushExpressionsRuleTest.xml | 129 +
.../logical/SimplifyFilterConditionRuleTest.xml | 226 +
.../logical/subquery/SubQueryAntiJoinTest.xml | 2254 ++++++++++
.../logical/subquery/SubQuerySemiJoinTest.xml | 4561 ++++++++++++++++++++
.../table/plan/stream/sql/SubplanReuseTest.xml | 8 +-
.../table/plan/stream/sql/agg/GroupingSetsTest.xml | 16 +-
.../flink/table/plan/stream/sql/join/JoinTest.xml | 15 +-
.../plan/stream/sql/join/SemiAntiJoinTest.xml | 2654 ++++++++++++
.../table/plan/batch/sql/DeadlockBreakupTest.scala | 1 -
.../sql/join/BroadcastHashSemiAntiJoinTest.scala | 179 +
.../sql/join/NestedLoopSemiAntiJoinTest.scala | 32 +
.../plan/batch/sql/join/SemiAntiJoinTest.scala | 26 +
.../plan/batch/sql/join/SemiAntiJoinTestBase.scala | 584 +++
.../sql/join/ShuffledHashSemiAntiJoinTest.scala | 185 +
.../batch/sql/join/SortMergeSemiAntiJoinTest.scala | 162 +
.../logical/FlinkJoinPushExpressionsRuleTest.scala | 77 +
.../logical/SimplifyFilterConditionRuleTest.scala | 110 +
.../logical/subquery/SubQueryAntiJoinTest.scala | 767 ++++
.../logical/subquery/SubQuerySemiJoinTest.scala | 1673 +++++++
.../rules/logical/subquery/SubQueryTestBase.scala | 43 +
.../SubqueryCorrelateVariablesValidationTest.scala | 128 +
.../plan/stream/sql/join/SemiAntiJoinTest.scala | 575 +++
74 files changed, 32592 insertions(+), 465 deletions(-)
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/calcite/rel/core/Join.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/calcite/rel/core/Join.java
new file mode 100644
index 0000000..657efbf
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/calcite/rel/core/Join.java
@@ -0,0 +1,341 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.rel.core;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import org.apache.calcite.rel.RelNode.Context;
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptCost;
+import org.apache.calcite.plan.RelOptPlanner;
+import org.apache.calcite.plan.RelTraitSet;
+import org.apache.calcite.rel.BiRel;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.RelWriter;
+import org.apache.calcite.rel.metadata.RelMdUtil;
+import org.apache.calcite.rel.metadata.RelMetadataQuery;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rel.type.RelDataTypeField;
+import org.apache.calcite.rex.RexChecker;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexShuttle;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.sql.validate.SqlValidatorUtil;
+import org.apache.calcite.util.Litmus;
+import org.apache.calcite.util.Util;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+import java.util.Set;
+
+/**
+ * This class is copied from https://github.com/apache/calcite/pull/1157 to support SEMI/ANTI join.
+ * NOTES: This file should be deleted when upgrading to a new calcite version
+ * which contains CALCITE-2969.
+ */
+
+/**
+ * Relational expression that combines two relational expressions according to
+ * some condition.
+ *
+ * <p>Each output row has columns from the left and right inputs.
+ * The set of output rows is a subset of the cartesian product of the two
+ * inputs; precisely which subset depends on the join condition.
+ */
+public abstract class Join extends BiRel {
+ //~ Instance fields --------------------------------------------------------
+
+ protected final RexNode condition;
+ protected final ImmutableSet<CorrelationId> variablesSet;
+
+ /**
+ * Values must be of enumeration {@link JoinRelType}, except that
+ * {@link JoinRelType#RIGHT} is disallowed.
+ */
+ protected final JoinRelType joinType;
+
+ protected final JoinInfo joinInfo;
+
+ //~ Constructors -----------------------------------------------------------
+
+ // The primary constructor below already takes
+ // "Set<CorrelationId> variablesSet". The deprecated overload after it still
+ // accepts the legacy "Set<String> variablesStopped" and converts it with
+ // CorrelationId.setOf before delegating.
+
+ /**
+ * Creates a Join.
+ *
+ * <p>{@code condition} and {@code joinType} must not be null; both are
+ * checked with {@link Objects#requireNonNull(Object)}. {@code variablesSet}
+ * is copied into an {@link ImmutableSet}, and the {@link JoinInfo} returned
+ * by {@link #analyzeCondition()} is computed here from the two inputs and
+ * the condition.
+ * Callers that still hold a {@code Set<String> variablesStopped} must go
+ * through the deprecated overload of this constructor.
+ *
+ * @param cluster Cluster
+ * @param traitSet Trait set
+ * @param left Left input
+ * @param right Right input
+ * @param condition Join condition; must not be null
+ * @param variablesSet Set of variables that are set by the
+ * LHS and used by the RHS and are not available to
+ * nodes above this Join in the tree
+ * @param joinType Join type; must not be null
+ */
+ protected Join(
+ RelOptCluster cluster,
+ RelTraitSet traitSet,
+ RelNode left,
+ RelNode right,
+ RexNode condition,
+ Set<CorrelationId> variablesSet,
+ JoinRelType joinType) {
+ super(cluster, traitSet, left, right);
+ this.condition = Objects.requireNonNull(condition);
+ this.variablesSet = ImmutableSet.copyOf(variablesSet);
+ this.joinType = Objects.requireNonNull(joinType);
+ this.joinInfo = JoinInfo.of(left, right, condition);
+ }
+
+ @Deprecated // to be removed before 2.0
+ protected Join(
+ RelOptCluster cluster,
+ RelTraitSet traitSet,
+ RelNode left,
+ RelNode right,
+ RexNode condition,
+ JoinRelType joinType,
+ Set<String> variablesStopped) {
+ this(cluster, traitSet, left, right, condition,
+ CorrelationId.setOf(variablesStopped), joinType);
+ }
+
+ //~ Methods ----------------------------------------------------------------
+
+ @Override public List<RexNode> getChildExps() {
+ return ImmutableList.of(condition);
+ }
+
+ @Override public RelNode accept(RexShuttle shuttle) {
+ RexNode condition = shuttle.apply(this.condition);
+ if (this.condition == condition) {
+ return this;
+ }
+ return copy(traitSet, condition, left, right, joinType, isSemiJoinDone());
+ }
+
+ public RexNode getCondition() {
+ return condition;
+ }
+
+ public JoinRelType getJoinType() {
+ return joinType;
+ }
+
+ @Override public boolean isValid(Litmus litmus, Context context) {
+ if (!super.isValid(litmus, context)) {
+ return false;
+ }
+ if (getRowType().getFieldCount()
+ != getSystemFieldList().size()
+ + left.getRowType().getFieldCount()
+ + (joinType.projectsRight() ? right.getRowType().getFieldCount() : 0)) {
+ return litmus.fail("field count mismatch");
+ }
+ if (condition != null) {
+ if (condition.getType().getSqlTypeName() != SqlTypeName.BOOLEAN) {
+ return litmus.fail("condition must be boolean: {}",
+ condition.getType());
+ }
+ // The input to the condition is a row type consisting of system
+ // fields, left fields, and right fields. Very similar to the
+ // output row type, except that fields have not yet been made
+ // nullable due to outer joins.
+ RexChecker checker =
+ new RexChecker(
+ getCluster().getTypeFactory().builder()
+ .addAll(getSystemFieldList())
+ .addAll(getLeft().getRowType().getFieldList())
+ .addAll(getRight().getRowType().getFieldList())
+ .build(),
+ context, litmus);
+ condition.accept(checker);
+ if (checker.getFailureCount() > 0) {
+ return litmus.fail(checker.getFailureCount()
+ + " failures in condition " + condition);
+ }
+ }
+ return litmus.succeed();
+ }
+
+ @Override public RelOptCost computeSelfCost(RelOptPlanner planner,
+ RelMetadataQuery mq) {
+ // Maybe we should remove this for semi-join ?
+ if (!joinType.projectsRight()) {
+ // REVIEW jvs 9-Apr-2006: Just for now...
+ return planner.getCostFactory().makeTinyCost();
+ }
+ double rowCount = mq.getRowCount(this);
+ return planner.getCostFactory().makeCost(rowCount, 0, 0);
+ }
+
+ /** @deprecated Use {@link RelMdUtil#getJoinRowCount(RelMetadataQuery, Join, RexNode)}. */
+ @Deprecated // to be removed before 2.0
+ public static double estimateJoinedRows(
+ Join joinRel,
+ RexNode condition) {
+ final RelMetadataQuery mq = RelMetadataQuery.instance();
+ return Util.first(RelMdUtil.getJoinRowCount(mq, joinRel, condition), 1D);
+ }
+
+ @Override public double estimateRowCount(RelMetadataQuery mq) {
+ return Util.first(RelMdUtil.getJoinRowCount(mq, this, condition), 1D);
+ }
+
+ @Override public Set<CorrelationId> getVariablesSet() {
+ return variablesSet;
+ }
+
+ @Override public RelWriter explainTerms(RelWriter pw) {
+ return super.explainTerms(pw)
+ .item("condition", condition)
+ .item("joinType", joinType.lowerName)
+ .itemIf(
+ "systemFields",
+ getSystemFieldList(),
+ !getSystemFieldList().isEmpty());
+ }
+
+ @Override protected RelDataType deriveRowType() {
+ assert getSystemFieldList() != null;
+ RelDataType leftType = left.getRowType();
+ RelDataType rightType = right.getRowType();
+ RelDataTypeFactory typeFactory = getCluster().getTypeFactory();
+ switch (joinType) {
+ case LEFT:
+ rightType = typeFactory.createTypeWithNullability(rightType, true);
+ break;
+ case RIGHT:
+ leftType = typeFactory.createTypeWithNullability(leftType, true);
+ break;
+ case FULL:
+ leftType = typeFactory.createTypeWithNullability(leftType, true);
+ rightType = typeFactory.createTypeWithNullability(rightType, true);
+ break;
+ case SEMI:
+ case ANTI:
+ rightType = null;
+ default:
+ break;
+ }
+ return createJoinType(typeFactory, leftType, rightType, null, getSystemFieldList());
+ }
+
+ /**
+ * Returns whether this join has already spawned a
+ * {@code SemiJoin} via
+ * {@link org.apache.calcite.rel.rules.JoinAddRedundantSemiJoinRule}.
+ *
+ * <p>The base implementation returns false.</p>
+ *
+ * @return whether this join has already spawned a semi join
+ */
+ public boolean isSemiJoinDone() {
+ return false;
+ }
+
+ /**
+ * Returns whether this Join is a semijoin.
+ *
+ * @return true if this Join's join type is {@link JoinRelType#SEMI}.
+ */
+ public boolean isSemiJoin() {
+ return joinType == JoinRelType.SEMI;
+ }
+
+ /**
+ * Returns a list of system fields that will be prefixed to
+ * output row type. The base implementation returns an empty list.
+ *
+ * @return list of system fields
+ */
+ public List<RelDataTypeField> getSystemFieldList() {
+ return Collections.emptyList();
+ }
+
+ @Deprecated // to be removed before 2.0
+ public static RelDataType deriveJoinRowType(
+ RelDataType leftType,
+ RelDataType rightType,
+ JoinRelType joinType,
+ RelDataTypeFactory typeFactory,
+ List<String> fieldNameList,
+ List<RelDataTypeField> systemFieldList) {
+ return SqlValidatorUtil.deriveJoinRowType(leftType, rightType, joinType,
+ typeFactory, fieldNameList, systemFieldList);
+ }
+
+ @Deprecated // to be removed before 2.0
+ public static RelDataType createJoinType(
+ RelDataTypeFactory typeFactory,
+ RelDataType leftType,
+ RelDataType rightType,
+ List<String> fieldNameList,
+ List<RelDataTypeField> systemFieldList) {
+ return SqlValidatorUtil.createJoinType(typeFactory, leftType, rightType,
+ fieldNameList, systemFieldList);
+ }
+
+ @Override public final Join copy(RelTraitSet traitSet, List<RelNode> inputs) {
+ assert inputs.size() == 2;
+ return copy(traitSet, getCondition(), inputs.get(0), inputs.get(1),
+ joinType, isSemiJoinDone());
+ }
+
+ /**
+ * Creates a copy of this join, overriding condition, inputs, join type and
+ * the semi-join-done flag.
+ *
+ * <p>General contract as {@link RelNode#copy}.
+ *
+ * @param traitSet Traits
+ * @param conditionExpr Condition
+ * @param left Left input
+ * @param right Right input
+ * @param joinType Join type
+ * @param semiJoinDone Whether this join has been translated to a
+ * semi-join
+ * @return Copy of this join
+ */
+ public abstract Join copy(RelTraitSet traitSet, RexNode conditionExpr,
+ RelNode left, RelNode right, JoinRelType joinType, boolean semiJoinDone);
+
+ /**
+ * Returns the analysis of the join condition computed by the constructor.
+ *
+ * @return Analyzed join condition
+ */
+ public JoinInfo analyzeCondition() {
+ return joinInfo;
+ }
+}
+
+// End Join.java
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/calcite/rel/core/JoinRelType.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/calcite/rel/core/JoinRelType.java
new file mode 100644
index 0000000..8e69d15
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/calcite/rel/core/JoinRelType.java
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.rel.core;
+
+import org.apache.calcite.linq4j.CorrelateJoinType;
+
+import java.util.Locale;
+
+/**
+ * This class is copied from https://github.com/apache/calcite/pull/1157 to support SEMI/ANTI join.
+ * NOTES: This file should be deleted when upgrading to a new calcite version
+ * which contains CALCITE-2969.
+ */
+
+/**
+ * Enumeration of join types.
+ */
+public enum JoinRelType {
+ /**
+ * Inner join.
+ */
+ INNER,
+
+ /**
+ * Left-outer join.
+ */
+ LEFT,
+
+ /**
+ * Right-outer join.
+ */
+ RIGHT,
+
+ /**
+ * Full-outer join.
+ */
+ FULL,
+
+ /**
+ * Semi-join.
+ *
+ * <p>For example, {@code EMP semi-join DEPT} finds all {@code EMP} records
+ * that have a corresponding {@code DEPT} record:
+ *
+ * <blockquote><pre>
+ * SELECT * FROM EMP
+ * WHERE EXISTS (SELECT 1 FROM DEPT
+ * WHERE DEPT.DEPTNO = EMP.DEPTNO)</pre>
+ * </blockquote>
+ */
+ SEMI,
+
+ /**
+ * Anti-join.
+ *
+ * <p>For example, {@code EMP anti-join DEPT} finds all {@code EMP} records
+ * that do not have a corresponding {@code DEPT} record:
+ *
+ * <blockquote><pre>
+ * SELECT * FROM EMP
+ * WHERE NOT EXISTS (SELECT 1 FROM DEPT
+ * WHERE DEPT.DEPTNO = EMP.DEPTNO)</pre>
+ * </blockquote>
+ */
+ ANTI;
+
+ /** Lower-case name. */
+ public final String lowerName = name().toLowerCase(Locale.ROOT);
+
+ /**
+ * Returns whether a join of this type may generate NULL values on the
+ * right-hand side.
+ */
+ public boolean generatesNullsOnRight() {
+ return (this == LEFT) || (this == FULL);
+ }
+
+ /**
+ * Returns whether a join of this type may generate NULL values on the
+ * left-hand side.
+ */
+ public boolean generatesNullsOnLeft() {
+ return (this == RIGHT) || (this == FULL);
+ }
+
+ /**
+ * Swaps left to right, and vice versa.
+ */
+ public JoinRelType swap() {
+ switch (this) {
+ case LEFT:
+ return RIGHT;
+ case RIGHT:
+ return LEFT;
+ default:
+ return this;
+ }
+ }
+
+ /** Returns whether this join type generates nulls on side #{@code i}. */
+ public boolean generatesNullsOn(int i) {
+ switch (i) {
+ case 0:
+ return generatesNullsOnLeft();
+ case 1:
+ return generatesNullsOnRight();
+ default:
+ throw new IllegalArgumentException("invalid: " + i);
+ }
+ }
+
+ /** Returns a join type similar to this but that does not generate nulls on
+ * the left. */
+ public JoinRelType cancelNullsOnLeft() {
+ switch (this) {
+ case RIGHT:
+ return INNER;
+ case FULL:
+ return LEFT;
+ default:
+ return this;
+ }
+ }
+
+ /** Returns a join type similar to this but that does not generate nulls on
+ * the right. */
+ public JoinRelType cancelNullsOnRight() {
+ switch (this) {
+ case LEFT:
+ return INNER;
+ case FULL:
+ return RIGHT;
+ default:
+ return this;
+ }
+ }
+
+ /**
+ * Transforms this JoinRelType to the equivalent {@link CorrelateJoinType}.
+ * Only {@link #INNER}, {@link #LEFT}, {@link #SEMI} and {@link #ANTI} convert.
+ *
+ * @return the matching linq4j join type
+ * @throws IllegalStateException if this is {@link #RIGHT} or {@link #FULL}
+ */
+ public CorrelateJoinType toLinq4j() {
+ switch (this) {
+ case INNER:
+ return CorrelateJoinType.INNER;
+ case LEFT:
+ return CorrelateJoinType.LEFT;
+ case SEMI:
+ return CorrelateJoinType.SEMI;
+ case ANTI:
+ return CorrelateJoinType.ANTI;
+ default:
+ throw new IllegalStateException(
+ "Unable to convert " + this + " to CorrelateJoinType");
+ }
+ }
+
+ /**
+ * Returns whether a join of this type emits the columns of its right input.
+ * {@link #SEMI} and {@link #ANTI} joins only emit the left input's columns.
+ */
+ public boolean projectsRight() {
+ return this != SEMI && this != ANTI;
+ }
+}
+
+// End JoinRelType.java
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
new file mode 100644
index 0000000..cef2de0
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
@@ -0,0 +1,2834 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.sql2rel;
+
+import org.apache.flink.table.plan.rules.logical.FlinkFilterJoinRule;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.ImmutableSortedMap;
+import com.google.common.collect.ImmutableSortedSet;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Multimap;
+import com.google.common.collect.MultimapBuilder;
+import com.google.common.collect.Sets;
+import com.google.common.collect.SortedSetMultimap;
+import org.apache.calcite.linq4j.Ord;
+import org.apache.calcite.linq4j.function.Function2;
+import org.apache.calcite.plan.Context;
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptCostImpl;
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.plan.hep.HepPlanner;
+import org.apache.calcite.plan.hep.HepProgram;
+import org.apache.calcite.plan.hep.HepRelVertex;
+import org.apache.calcite.rel.BiRel;
+import org.apache.calcite.rel.RelCollation;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.RelShuttleImpl;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.Correlate;
+import org.apache.calcite.rel.core.CorrelationId;
+import org.apache.calcite.rel.core.Filter;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.core.RelFactories;
+import org.apache.calcite.rel.core.Sort;
+import org.apache.calcite.rel.core.Values;
+import org.apache.calcite.rel.logical.LogicalAggregate;
+import org.apache.calcite.rel.logical.LogicalCorrelate;
+import org.apache.calcite.rel.logical.LogicalFilter;
+import org.apache.calcite.rel.logical.LogicalJoin;
+import org.apache.calcite.rel.logical.LogicalProject;
+import org.apache.calcite.rel.logical.LogicalSnapshot;
+import org.apache.calcite.rel.logical.LogicalSort;
+import org.apache.calcite.rel.metadata.RelMdUtil;
+import org.apache.calcite.rel.metadata.RelMetadataQuery;
+import org.apache.calcite.rel.rules.FilterCorrelateRule;
+import org.apache.calcite.rel.rules.FilterJoinRule;
+import org.apache.calcite.rel.rules.FilterProjectTransposeRule;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rel.type.RelDataTypeField;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexCorrelVariable;
+import org.apache.calcite.rex.RexFieldAccess;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexLiteral;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexShuttle;
+import org.apache.calcite.rex.RexSubQuery;
+import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.rex.RexVisitorImpl;
+import org.apache.calcite.sql.SemiJoinType;
+import org.apache.calcite.sql.SqlExplainFormat;
+import org.apache.calcite.sql.SqlExplainLevel;
+import org.apache.calcite.sql.SqlFunction;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.SqlOperator;
+import org.apache.calcite.sql.fun.SqlCountAggFunction;
+import org.apache.calcite.sql.fun.SqlSingleValueAggFunction;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.tools.RelBuilderFactory;
+import org.apache.calcite.util.Holder;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.Litmus;
+import org.apache.calcite.util.Pair;
+import org.apache.calcite.util.ReflectUtil;
+import org.apache.calcite.util.ReflectiveVisitor;
+import org.apache.calcite.util.Util;
+import org.apache.calcite.util.mapping.Mappings;
+import org.apache.calcite.util.trace.CalciteTrace;
+import org.slf4j.Logger;
+
+import javax.annotation.Nonnull;
+
+import java.math.BigDecimal;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.NavigableMap;
+import java.util.Objects;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+/**
+ * This class is copied from Apache Calcite except that it supports SEMI/ANTI join.
+ * NOTES: This file should be deleted when upgrading to a new calcite version which contains CALCITE-2969.
+ */
+
+/**
+ * RelDecorrelator replaces all correlated expressions (corExp) in a relational
+ * expression (RelNode) tree with non-correlated expressions that are produced
+ * from joining the RelNode that produces the corExp with the RelNode that
+ * references it.
+ *
+ * <p>TODO:</p>
+ * <ul>
+ * <li>replace {@code CorelMap} constructor parameter with a RelNode
+ * <li>make {@link #currentRel} immutable (would require a fresh
+ * RelDecorrelator for each node being decorrelated)</li>
+ * <li>make fields of {@code CorelMap} immutable</li>
+ * <li>make sub-class rules static, and have them create their own
+ * de-correlator</li>
+ * </ul>
+ */
+public class RelDecorrelator implements ReflectiveVisitor {
+ //~ Static fields/initializers ---------------------------------------------
+
+ private static final Logger SQL2REL_LOGGER =
+ CalciteTrace.getSqlToRelTracer();
+
+ //~ Instance fields --------------------------------------------------------
+
+ private final RelBuilder relBuilder;
+
+ // map built during translation
+ private CorelMap cm;
+
+ private final ReflectUtil.MethodDispatcher<Frame> dispatcher =
+ ReflectUtil.createMethodDispatcher(Frame.class, this, "decorrelateRel",
+ RelNode.class);
+
+ // The rel which is being visited
+ private RelNode currentRel;
+
+ private final Context context;
+
+ /** Built during decorrelation, of rel to all the newly created correlated
+ * variables in its output, and to map old input positions to new input
+ * positions. This is from the view point of the parent rel of a new rel. */
+ private final Map<RelNode, Frame> map = new HashMap<>();
+
+ private final HashSet<LogicalCorrelate> generatedCorRels = new HashSet<>();
+
+ //~ Constructors -----------------------------------------------------------
+
+ /** Creates a decorrelator; stores the three collaborators without copying them. */
+ private RelDecorrelator(CorelMap cm, Context context, RelBuilder relBuilder) {
+ // Final fields first, then the correlation map (reassigned later by
+ // setCurrent whenever a non-null correlate becomes current).
+ this.relBuilder = relBuilder;
+ this.context = context;
+ this.cm = cm;
+ }
+
+ //~ Methods ----------------------------------------------------------------
+
+ @Deprecated // to be removed before 2.0
+ public static RelNode decorrelateQuery(RelNode rootRel) {
+ // Legacy entry point: default logical builder on the root's cluster.
+ return decorrelateQuery(
+ rootRel, RelFactories.LOGICAL_BUILDER.create(rootRel.getCluster(), null));
+ }
+
+ /**
+ * Decorrelates a query; the main entry point to {@code RelDecorrelator}.
+ *
+ * <p>Returns {@code rootRel} itself when the correlation map built from it
+ * reports no correlation.
+ *
+ * @param rootRel Root node of the query
+ * @param relBuilder Builder for relational expressions
+ * @return Equivalent query, rewritten first by rule and then, if correlate
+ * nodes are still registered in the correlation map, by {@code decorrelate}
+ */
+ public static RelNode decorrelateQuery(RelNode rootRel,
+ RelBuilder relBuilder) {
+ final CorelMap initialMap = new CorelMapBuilder().build(rootRel);
+ if (!initialMap.hasCorrelation()) {
+ // Nothing to rewrite.
+ return rootRel;
+ }
+
+ final Context plannerContext = rootRel.getCluster().getPlanner().getContext();
+ final RelDecorrelator worker =
+ new RelDecorrelator(initialMap, plannerContext, relBuilder);
+
+ // Phase 1: rule-based removal of correlations.
+ final RelNode ruleResult = worker.removeCorrelationViaRule(rootRel);
+ if (SQL2REL_LOGGER.isDebugEnabled()) {
+ final String dump = RelOptUtil.dumpPlan("Plan after removing Correlator",
+ ruleResult, SqlExplainFormat.TEXT, SqlExplainLevel.EXPPLAN_ATTRIBUTES);
+ SQL2REL_LOGGER.debug(dump);
+ }
+
+ // Phase 2: general decorrelation, only when correlates remain in the map.
+ return worker.cm.mapCorToCorRel.isEmpty()
+ ? ruleResult
+ : worker.decorrelate(ruleResult);
+ }
+
+ private void setCurrent(RelNode root, LogicalCorrelate corRel) {
+ this.currentRel = corRel;
+ if (null != corRel) { // a non-null correlate also rebuilds the correlation map
+ this.cm = new CorelMapBuilder().build(Util.first(root, corRel));
+ }
+ }
+
+ private RelBuilderFactory relBuilderFactory() {
+ return RelBuilder.proto(this.relBuilder); // factory derived from this instance's builder
+ }
+
+ private RelNode decorrelate(RelNode root) {
+ // Pre-pass: first adjust count() expressions if any, then apply the filter rules below.
+ final RelBuilderFactory f = relBuilderFactory();
+ HepProgram program = HepProgram.builder()
+ .addRuleInstance(new AdjustProjectForCountAggregateRule(false, f))
+ .addRuleInstance(new AdjustProjectForCountAggregateRule(true, f))
+ .addRuleInstance(
+ new FilterJoinRule.FilterIntoJoinRule(true, f,
+ FilterJoinRule.TRUE_PREDICATE))
+ .addRuleInstance(
+ new FilterProjectTransposeRule(Filter.class, Project.class, true,
+ true, f))
+ .addRuleInstance(new FilterCorrelateRule(f))
+ .build();
+
+ HepPlanner planner = createPlanner(program);
+
+ planner.setRoot(root);
+ root = planner.findBestExp();
+
+ // Perform decorrelation; frames recorded by any earlier run are discarded first.
+ map.clear();
+
+ final Frame frame = getInvoke(root, null);
+ if (frame != null) {
+ // has been rewritten; apply the Flink filter/join-condition rules post-decorrelation
+ final HepProgram program2 = HepProgram.builder()
+ .addRuleInstance(
+ new FlinkFilterJoinRule.FlinkFilterIntoJoinRule(
+ true, f,
+ FlinkFilterJoinRule.TRUE_PREDICATE))
+ .addRuleInstance(
+ new FlinkFilterJoinRule.FlinkJoinConditionPushRule(
+ f,
+ FlinkFilterJoinRule.TRUE_PREDICATE))
+ .build();
+
+ final HepPlanner planner2 = createPlanner(program2);
+ final RelNode newRoot = frame.r;
+ planner2.setRoot(newRoot);
+ return planner2.findBestExp();
+ }
+
+ return root;
+ }
+
+ private Function2<RelNode, RelNode, Void> createCopyHook() {
+ return (oldNode, newNode) -> {
+ if (cm.mapRefRelToCorRef.containsKey(oldNode)) {
+ cm.mapRefRelToCorRef.putAll(newNode,
+ cm.mapRefRelToCorRef.get(oldNode));
+ }
+ if (oldNode instanceof LogicalCorrelate
+ && newNode instanceof LogicalCorrelate) {
+ LogicalCorrelate oldCor = (LogicalCorrelate) oldNode;
+ CorrelationId c = oldCor.getCorrelationId();
+ if (cm.mapCorToCorRel.get(c) == oldNode) {
+ cm.mapCorToCorRel.put(c, newNode);
+ }
+
+ if (generatedCorRels.contains(oldNode)) {
+ generatedCorRels.add((LogicalCorrelate) newNode);
+ }
+ }
+ return null;
+ };
+ }
+
+ private HepPlanner createPlanner(HepProgram program) {
+ // Create a planner with a hook to update the mapping tables when a
+ // node is copied when it is registered.
+ return new HepPlanner(
+ program,
+ context,
+ true,
+ createCopyHook(),
+ RelOptCostImpl.FACTORY);
+ }
+
+ /**
+  * Runs a Hep program made of {@code RemoveSingleAggregateRule},
+  * {@code RemoveCorrelationForScalarProjectRule} and
+  * {@code RemoveCorrelationForScalarAggregateRule} over {@code root}.
+  *
+  * @param root tree to rewrite
+  * @return the best expression found by the planner
+  */
+ public RelNode removeCorrelationViaRule(RelNode root) {
+   final RelBuilderFactory factory = relBuilderFactory();
+   final HepPlanner hep = createPlanner(
+       HepProgram.builder()
+           .addRuleInstance(new RemoveSingleAggregateRule(factory))
+           .addRuleInstance(new RemoveCorrelationForScalarProjectRule(factory))
+           .addRuleInstance(new RemoveCorrelationForScalarAggregateRule(factory))
+           .build());
+   hep.setRoot(root);
+   return hep.findBestExp();
+ }
+
+ /**
+  * Rewrites {@code exp} by visiting it with a {@code DecorrelateRexShuttle}
+  * built from the given rel, frame map and correlation map.
+  *
+  * @return whatever the shuttle returns for {@code exp}
+  */
+ protected RexNode decorrelateExpr(RelNode currentRel,
+     Map<RelNode, Frame> map, CorelMap cm, RexNode exp) {
+   return exp.accept(new DecorrelateRexShuttle(currentRel, map, cm));
+ }
+
+ /**
+  * Visits {@code exp} with a {@code RemoveCorrelationRexShuttle} configured
+  * with no null indicator and an empty {@code isCount} set.
+  */
+ protected RexNode removeCorrelationExpr(
+     RexNode exp,
+     boolean projectPulledAboveLeftCorrelator) {
+   return exp.accept(
+       new RemoveCorrelationRexShuttle(relBuilder.getRexBuilder(),
+           projectPulledAboveLeftCorrelator, null, ImmutableSet.of()));
+ }
+
+ /**
+  * Visits {@code exp} with a {@code RemoveCorrelationRexShuttle} configured
+  * with the given null indicator and an empty {@code isCount} set.
+  */
+ protected RexNode removeCorrelationExpr(
+     RexNode exp,
+     boolean projectPulledAboveLeftCorrelator,
+     RexInputRef nullIndicator) {
+   return exp.accept(
+       new RemoveCorrelationRexShuttle(relBuilder.getRexBuilder(),
+           projectPulledAboveLeftCorrelator, nullIndicator,
+           ImmutableSet.of()));
+ }
+
+ /**
+  * Visits {@code exp} with a {@code RemoveCorrelationRexShuttle} configured
+  * with no null indicator and the given {@code isCount} set.
+  */
+ protected RexNode removeCorrelationExpr(
+     RexNode exp,
+     boolean projectPulledAboveLeftCorrelator,
+     Set<Integer> isCount) {
+   return exp.accept(
+       new RemoveCorrelationRexShuttle(relBuilder.getRexBuilder(),
+           projectPulledAboveLeftCorrelator, null, isCount));
+ }
+
+ /**
+  * Fallback if none of the other {@code decorrelateRel} methods match.
+  *
+  * <p>Every input is decorrelated in turn; the rewrite is abandoned (null is
+  * returned) as soon as one input was not rewritten or produces correlated
+  * variables. Otherwise {@code rel} is re-created over the new inputs with an
+  * identity output mapping and no correlated outputs.
+  */
+ public Frame decorrelateRel(RelNode rel) {
+   RelNode rewritten = rel.copy(rel.getTraitSet(), rel.getInputs());
+
+   if (!rel.getInputs().isEmpty()) {
+     final List<RelNode> oldInputs = rel.getInputs();
+     final List<RelNode> newInputs = new ArrayList<>();
+     for (int i = 0; i < oldInputs.size(); ++i) {
+       final Frame inputFrame = getInvoke(oldInputs.get(i), rel);
+       final boolean usable =
+           inputFrame != null && inputFrame.corDefOutputs.isEmpty();
+       if (!usable) {
+         // Input is not rewritten, or it produces correlated variables:
+         // terminate the rewrite.
+         return null;
+       }
+       newInputs.add(inputFrame.r);
+       rewritten.replaceInput(i, inputFrame.r);
+     }
+
+     if (!Util.equalShallow(oldInputs, newInputs)) {
+       rewritten = rel.copy(rel.getTraitSet(), newInputs);
+     }
+   }
+
+   // Output positions do not change since no corVars come from below.
+   final int fieldCount = rel.getRowType().getFieldCount();
+   return register(rel, rewritten, identityMap(fieldCount),
+       ImmutableSortedMap.of());
+ }
+
+ /**
+  * Rewrites a {@link Sort} by re-targeting its collation at the decorrelated
+  * input.
+  *
+  * <p>A Sort only refers to its input through the field positions in its
+  * collation, and it does not reorder its input's columns, so the input's
+  * position mapping and correlated outputs are registered unchanged.
+  *
+  * @param rel Sort to be rewritten
+  * @return the registered frame, or null if the input was not rewritten
+  */
+ public Frame decorrelateRel(Sort rel) {
+   // Sort itself should not reference corVars.
+   assert !cm.mapRefRelToCorRef.containsKey(rel);
+
+   final RelNode oldInput = rel.getInput();
+   final Frame inputFrame = getInvoke(oldInput, rel);
+   if (inputFrame == null) {
+     // Input has not been rewritten, so leave this rel alone.
+     return null;
+   }
+   final RelNode newInput = inputFrame.r;
+
+   // Translate the collation's field positions to the new input's layout.
+   final int oldFieldCount = oldInput.getRowType().getFieldCount();
+   final int newFieldCount = newInput.getRowType().getFieldCount();
+   final Mappings.TargetMapping oldToNew =
+       Mappings.target(inputFrame.oldToNewOutputs, oldFieldCount,
+           newFieldCount);
+   final RelCollation newCollation =
+       RexUtil.apply(oldToNew, rel.getCollation());
+
+   final Sort newSort =
+       LogicalSort.create(newInput, newCollation, rel.offset, rel.fetch);
+
+   return register(rel, newSort, inputFrame.oldToNewOutputs,
+       inputFrame.corDefOutputs);
+ }
+
+ /**
+  * Handles a {@link Values}: it has no inputs, so there is nothing to
+  * rewrite and no frame is produced.
+  *
+  * @param rel Values to be rewritten
+  * @return always null
+  */
+ public Frame decorrelateRel(Values rel) {
+   final Frame notRewritten = null;
+   return notRewritten;
+ }
+
+ /**
+ * Rewrites a {@link LogicalAggregate}.
+ *
+ * <p>The input is decorrelated first. A Project is then placed over the new
+ * input that emits the group keys first (skipping keys that are literal
+ * projections in the new input), then the correlated-variable columns
+ * produced by the input, then every remaining field. The new Aggregate
+ * groups by the first {@code newGroupKeyCount} columns of that Project, so
+ * the correlated variables become extra group keys. Skipped literal keys
+ * are re-inserted by a Project on top of the new Aggregate.
+ *
+ * @param rel Aggregate to rewrite
+ * @return the frame registered for {@code rel}, or null if the input was not
+ * rewritten
+ */
+ public Frame decorrelateRel(LogicalAggregate rel) {
+ //
+ // Rewrite logic:
+ //
+ // 1. Permute the group by keys to the front.
+ // 2. If the input of an aggregate produces correlated variables,
+ // add them to the group list.
+ // 3. Change aggCalls to reference the new project.
+ //
+
+ // Aggregate itself should not reference corVars.
+ assert !cm.mapRefRelToCorRef.containsKey(rel);
+
+ final RelNode oldInput = rel.getInput();
+ final Frame frame = getInvoke(oldInput, rel);
+ if (frame == null) {
+ // If input has not been rewritten, do not rewrite this rel.
+ return null;
+ }
+ final RelNode newInput = frame.r;
+
+ // map from a field position in newInput to its position in the new Project
+ Map<Integer, Integer> mapNewInputToProjOutputs = new HashMap<>();
+ final int oldGroupKeyCount = rel.getGroupSet().cardinality();
+
+ // Project projects the original expressions,
+ // plus any correlated variables the input wants to pass along.
+ final List<Pair<RexNode, String>> projects = new ArrayList<>();
+
+ List<RelDataTypeField> newInputOutput =
+ newInput.getRowType().getFieldList();
+
+ int newPos = 0; // next free output position in the new Project
+
+ // oldInput has the original group by keys in the front.
+ final NavigableMap<Integer, RexLiteral> omittedConstants = new TreeMap<>();
+ for (int i = 0; i < oldGroupKeyCount; i++) {
+ final RexLiteral constant = projectedLiteral(newInput, i);
+ if (constant != null) {
+ // Exclude constants. Aggregate({true}) occurs because Aggregate({})
+ // would generate 1 row even when applied to an empty table.
+ omittedConstants.put(i, constant);
+ continue;
+ }
+ int newInputPos = frame.oldToNewOutputs.get(i); // where old key i lives in newInput
+ projects.add(RexInputRef.of2(newInputPos, newInputOutput));
+ mapNewInputToProjOutputs.put(newInputPos, newPos);
+ newPos++;
+ }
+
+ final SortedMap<CorDef, Integer> corDefOutputs = new TreeMap<>();
+ if (!frame.corDefOutputs.isEmpty()) {
+ // If input produces correlated variables, move them to the front,
+ // right after any existing GROUP BY fields.
+
+ // Now add the corVars from the input, continuing from the current
+ // value of newPos (i.e. right after the group keys that were kept).
+ for (Map.Entry<CorDef, Integer> entry : frame.corDefOutputs.entrySet()) {
+ projects.add(RexInputRef.of2(entry.getValue(), newInputOutput));
+
+ corDefOutputs.put(entry.getKey(), newPos);
+ mapNewInputToProjOutputs.put(entry.getValue(), newPos);
+ newPos++;
+ }
+ }
+
+ // add the remaining fields
+ final int newGroupKeyCount = newPos; // kept group keys + corVars
+ for (int i = 0; i < newInputOutput.size(); i++) {
+ if (!mapNewInputToProjOutputs.containsKey(i)) {
+ projects.add(RexInputRef.of2(i, newInputOutput));
+ mapNewInputToProjOutputs.put(i, newPos);
+ newPos++;
+ }
+ }
+
+ assert newPos == newInputOutput.size();
+
+ // This Project will be what the old input maps to,
+ // replacing any previous mapping from old input.
+ RelNode newProject = relBuilder.push(newInput)
+ .projectNamed(Pair.left(projects), Pair.right(projects), true)
+ .build();
+
+ // update mappings:
+ // oldInput ----> newInput
+ //
+ // newProject
+ // |
+ // oldInput ----> newInput
+ //
+ // is transformed to
+ //
+ // oldInput ----> newProject
+ // |
+ // newInput
+ Map<Integer, Integer> combinedMap = new HashMap<>();
+
+ for (Integer oldInputPos : frame.oldToNewOutputs.keySet()) {
+ combinedMap.put(oldInputPos,
+ mapNewInputToProjOutputs.get(
+ frame.oldToNewOutputs.get(oldInputPos)));
+ }
+
+ register(oldInput, newProject, combinedMap, corDefOutputs);
+
+ // now it's time to rewrite the Aggregate
+ final ImmutableBitSet newGroupSet = ImmutableBitSet.range(newGroupKeyCount);
+ List<AggregateCall> newAggCalls = new ArrayList<>();
+ List<AggregateCall> oldAggCalls = rel.getAggCallList();
+
+ ImmutableList<ImmutableBitSet> newGroupSets = null; // stays null for a SIMPLE group
+ if (rel.getGroupType() != Aggregate.Group.SIMPLE) {
+ final ImmutableBitSet addedGroupSet =
+ ImmutableBitSet.range(oldGroupKeyCount, newGroupKeyCount);
+ final Iterable<ImmutableBitSet> tmpGroupSets =
+ Iterables.transform(rel.getGroupSets(),
+ bitSet -> bitSet.union(addedGroupSet));
+ newGroupSets = ImmutableBitSet.ORDERING.immutableSortedCopy(tmpGroupSets);
+ }
+
+ int oldInputOutputFieldCount = rel.getGroupSet().cardinality();
+ int newInputOutputFieldCount = newGroupSet.cardinality();
+
+ int i = -1; // ordinal of the aggregate call being processed
+ for (AggregateCall oldAggCall : oldAggCalls) {
+ ++i;
+ List<Integer> oldAggArgs = oldAggCall.getArgList();
+
+ List<Integer> aggArgs = new ArrayList<>();
+
+ // Adjust the Aggregate argument positions.
+ // Note Aggregate does not change input ordering, so the input
+ // output position mapping can be used to derive the new positions
+ // for the argument.
+ for (int oldPos : oldAggArgs) {
+ aggArgs.add(combinedMap.get(oldPos));
+ }
+ final int filterArg = oldAggCall.filterArg < 0 ? oldAggCall.filterArg // negative value is passed through unchanged
+ : combinedMap.get(oldAggCall.filterArg);
+
+ newAggCalls.add(
+ oldAggCall.adaptTo(newProject, aggArgs, filterArg,
+ oldGroupKeyCount, newGroupKeyCount));
+
+ // The old to new output position mapping will be the same as that
+ // of newProject, plus any aggregates that the oldAgg produces.
+ combinedMap.put(
+ oldInputOutputFieldCount + i,
+ newInputOutputFieldCount + i);
+ }
+
+ relBuilder.push(
+ LogicalAggregate.create(newProject, newGroupSet, newGroupSets, newAggCalls));
+
+ if (!omittedConstants.isEmpty()) {
+ final List<RexNode> postProjects = new ArrayList<>(relBuilder.fields());
+ for (Map.Entry<Integer, RexLiteral> entry
+ : omittedConstants.descendingMap().entrySet()) { // highest old key position first
+ postProjects.add(entry.getKey() + frame.corDefOutputs.size(), // insert index: old key position plus number of input corVars
+ entry.getValue());
+ }
+ relBuilder.project(postProjects);
+ }
+
+ // Aggregate does not change input ordering so corVars will be
+ // located at the same position as the input newProject.
+ return register(rel, relBuilder.build(), combinedMap, corDefOutputs);
+ }
+
+ /**
+  * Dispatches {@code r} to the matching {@code decorrelateRel} overload via
+  * {@code dispatcher}, stores a non-null result in {@code map} under
+  * {@code r}, and then makes {@code parent} the current rel.
+  *
+  * @param r      rel to decorrelate
+  * @param parent rel to restore as {@code currentRel} afterwards
+  * @return the frame produced for {@code r}, possibly null
+  */
+ public Frame getInvoke(RelNode r, RelNode parent) {
+   final Frame produced = dispatcher.invoke(r);
+   if (produced != null) {
+     map.put(r, produced);
+   }
+   // Set after the dispatch, whether or not a frame was produced.
+   currentRel = parent;
+   return produced;
+ }
+
+ /** Returns the {@code i}-th projected expression of {@code rel} when
+  * {@code rel} is a {@link Project} and that expression is a literal;
+  * otherwise returns null. */
+ private static RexLiteral projectedLiteral(RelNode rel, int i) {
+   if (!(rel instanceof Project)) {
+     return null;
+   }
+   final RexNode expr = ((Project) rel).getProjects().get(i);
+   return expr instanceof RexLiteral ? (RexLiteral) expr : null;
+ }
+
+ /**
+ * Rewrite LogicalProject.
+ *
+ * @param rel the project rel to rewrite
+ */
+ public Frame decorrelateRel(LogicalProject rel) {
+ //
+ // Rewrite logic:
+ //
+ // 1. Pass along any correlated variables coming from the input.
+ //
+
+ final RelNode oldInput = rel.getInput();
+ Frame frame = getInvoke(oldInput, rel);
+ if (frame == null) {
+ // If input has not been rewritten, do not rewrite this rel.
+ return null;
+ }
+ final List<RexNode> oldProjects = rel.getProjects();
+ final List<RelDataTypeField> relOutput = rel.getRowType().getFieldList();
+
+ // Project projects the original expressions,
+ // plus any correlated variables the input wants to pass along.
+ final List<Pair<RexNode, String>> projects = new ArrayList<>();
+
+ // If this Project has correlated reference, create value generator
+ // and produce the correlated variables in the new output.
+ if (cm.mapRefRelToCorRef.containsKey(rel)) {
+ frame = decorrelateInputWithValueGenerator(rel, frame);
+ }
+
+ // Project projects the original expressions
+ final Map<Integer, Integer> mapOldToNewOutputs = new HashMap<>();
+ int newPos;
+ for (newPos = 0; newPos < oldProjects.size(); newPos++) {
+ projects.add(
+ newPos,
+ Pair.of(
+ decorrelateExpr(currentRel, map, cm, oldProjects.get(newPos)),
+ relOutput.get(newPos).getName()));
+ mapOldToNewOutputs.put(newPos, newPos);
+ }
+
+ // Project any correlated variables the input wants to pass along.
+ final SortedMap<CorDef, Integer> corDefOutputs = new TreeMap<>();
+ for (Map.Entry<CorDef, Integer> entry : frame.corDefOutputs.entrySet()) {
+ projects.add(
+ RexInputRef.of2(entry.getValue(),
+ frame.r.getRowType().getFieldList()));
+ corDefOutputs.put(entry.getKey(), newPos);
+ newPos++;
+ }
+
+ RelNode newProject = relBuilder.push(frame.r)
+ .projectNamed(Pair.left(projects), Pair.right(projects), true)
+ .build();
+
+ return register(rel, newProject, mapOldToNewOutputs, corDefOutputs);
+ }
+
+ /**
+ * Create RelNode tree that produces a list of correlated variables.
+ *
+ * <p>Works in three passes over {@code correlations}: (1) for each new
+ * input, collect the distinct new field positions that are referenced;
+ * (2) for each new input seen for the first time, project those positions,
+ * make the rows distinct, and inner-join the result (condition TRUE) onto
+ * what has been built so far; (3) record in {@code corDefOutputs} the
+ * output position of every correlated variable.
+ *
+ * @param correlations correlated variables to generate
+ * @param valueGenFieldOffset offset in the output that generated columns
+ * will start
+ * @param corDefOutputs output positions for the correlated variables
+ * generated; this map is written to by this method
+ * @return RelNode the root of the resultant RelNode tree; null if
+ * {@code correlations} yields no elements
+ */
+ private RelNode createValueGenerator(
+ Iterable<CorRef> correlations,
+ int valueGenFieldOffset,
+ SortedMap<CorDef, Integer> corDefOutputs) {
+ final Map<RelNode, List<Integer>> mapNewInputToOutputs = new HashMap<>();
+
+ final Map<RelNode, Integer> mapNewInputToNewOffset = new HashMap<>();
+
+ // Input provides the definition of a correlated variable.
+ // Add to map all the referenced positions (relative to each input rel).
+ for (CorRef corVar : correlations) {
+ final int oldCorVarOffset = corVar.field;
+
+ final RelNode oldInput = getCorRel(corVar);
+ assert oldInput != null;
+ final Frame frame = getFrame(oldInput, true); // safe=true: an identity frame is returned when none is registered
+ assert frame != null;
+ final RelNode newInput = frame.r;
+
+ final List<Integer> newLocalOutputs;
+ if (!mapNewInputToOutputs.containsKey(newInput)) {
+ newLocalOutputs = new ArrayList<>();
+ } else {
+ newLocalOutputs = mapNewInputToOutputs.get(newInput);
+ }
+
+ final int newCorVarOffset = frame.oldToNewOutputs.get(oldCorVarOffset);
+
+ // Add all unique positions referenced.
+ if (!newLocalOutputs.contains(newCorVarOffset)) {
+ newLocalOutputs.add(newCorVarOffset);
+ }
+ mapNewInputToOutputs.put(newInput, newLocalOutputs);
+ }
+
+ int offset = 0; // number of columns contributed by the inputs joined so far
+
+ // Project only the correlated fields out of each input
+ // and join the project together.
+ // To make sure the plan does not change in terms of join order,
+ // join these rels based on their occurrence in corVar list which
+ // is sorted.
+ final Set<RelNode> joinedInputs = new HashSet<>();
+
+ RelNode r = null;
+ for (CorRef corVar : correlations) {
+ final RelNode oldInput = getCorRel(corVar);
+ assert oldInput != null;
+ final RelNode newInput = getFrame(oldInput, true).r;
+ assert newInput != null;
+
+ if (!joinedInputs.contains(newInput)) {
+ RelNode project =
+ RelOptUtil.createProject(newInput,
+ mapNewInputToOutputs.get(newInput));
+ RelNode distinct = relBuilder.push(project)
+ .distinct()
+ .build();
+ RelOptCluster cluster = distinct.getCluster();
+
+ joinedInputs.add(newInput);
+ mapNewInputToNewOffset.put(newInput, offset);
+ offset += distinct.getRowType().getFieldCount();
+
+ if (r == null) {
+ r = distinct;
+ } else {
+ r =
+ LogicalJoin.create(r, distinct,
+ cluster.getRexBuilder().makeLiteral(true),
+ ImmutableSet.of(), JoinRelType.INNER);
+ }
+ }
+ }
+
+ // Translate the positions of correlated variables to be relative to
+ // the join output, leaving room for valueGenFieldOffset because
+ // valueGenerators are joined with the original left input of the rel
+ // referencing correlated variables.
+ for (CorRef corRef : correlations) {
+ // The first input of a Correlate is always the rel defining
+ // the correlated variables.
+ final RelNode oldInput = getCorRel(corRef);
+ assert oldInput != null;
+ final Frame frame = getFrame(oldInput, true);
+ final RelNode newInput = frame.r;
+ assert newInput != null;
+
+ final List<Integer> newLocalOutputs = mapNewInputToOutputs.get(newInput);
+
+ final int newLocalOutput = frame.oldToNewOutputs.get(corRef.field);
+
+ // newOutput is the index of the corVar in the referenced
+ // position list plus the offset of referenced position list of
+ // each newInput.
+ final int newOutput =
+ newLocalOutputs.indexOf(newLocalOutput)
+ + mapNewInputToNewOffset.get(newInput)
+ + valueGenFieldOffset;
+
+ corDefOutputs.put(corRef.def(), newOutput);
+ }
+
+ return r;
+ }
+
+ /**
+  * Looks up the frame registered for {@code r}.
+  *
+  * @param r    rel whose frame is wanted
+  * @param safe when true and no frame is registered, an identity frame
+  *     mapping {@code r} to itself with no correlated outputs is returned
+  *     instead of null
+  */
+ private Frame getFrame(RelNode r, boolean safe) {
+   final Frame registered = map.get(r);
+   if (registered != null || !safe) {
+     return registered;
+   }
+   final int fieldCount = r.getRowType().getFieldCount();
+   return new Frame(r, r, ImmutableSortedMap.of(), identityMap(fieldCount));
+ }
+
+ /** Returns input 0 of the rel that {@code cm.mapCorToCorRel} associates
+  * with the correlation id of {@code corVar}. */
+ private RelNode getCorRel(CorRef corVar) {
+   return cm.mapCorToCorRel.get(corVar.corr).getInput(0);
+ }
+
+ /** Adds a value generator to satisfy the correlating variables used by
+  * {@code rel}, unless {@code rel} uses none or {@code frame} already
+  * provides all of them; in those cases {@code frame} is returned as is. */
+ private Frame maybeAddValueGenerator(RelNode rel, Frame frame) {
+   final CorelMap localMap = new CorelMapBuilder().build(frame.r, rel);
+   if (localMap.mapRefRelToCorRef.containsKey(rel)) {
+     final Collection<CorRef> required = localMap.mapRefRelToCorRef.get(rel);
+     final ImmutableSortedSet<CorDef> provided = frame.corDefOutputs.keySet();
+     if (!hasAll(required, provided)) {
+       return decorrelateInputWithValueGenerator(rel, frame);
+     }
+   }
+   return frame;
+ }
+
+ /** Returns whether every {@link CorRef} in {@code corRefs} is satisfied by
+  * at least one {@link CorDef} in {@code corDefs}; true when
+  * {@code corRefs} is empty. */
+ private boolean hasAll(Collection<CorRef> corRefs,
+     Collection<CorDef> corDefs) {
+   return corRefs.stream().allMatch(ref -> has(corDefs, ref));
+ }
+
+ /** Returns whether a {@link CorRef} is satisfied by at least one of a
+  * collection of {@link CorDef}s, i.e. some definition has an equal
+  * correlation id and the same field. */
+ private boolean has(Collection<CorDef> corDefs, CorRef corr) {
+   return corDefs.stream()
+       .anyMatch(def ->
+           def.corr.equals(corr.corr) && def.field == corr.field);
+ }
+
+ private Frame decorrelateInputWithValueGenerator(RelNode rel, Frame frame) { // makes rel's input supply the corVars rel references; registers and returns the frame for rel.getInput(0)
+ // currently only handles one input
+ assert rel.getInputs().size() == 1;
+ RelNode oldInput = frame.r; // note: despite the name, this is the already-rewritten input
+
+ final SortedMap<CorDef, Integer> corDefOutputs =
+ new TreeMap<>(frame.corDefOutputs);
+
+ final Collection<CorRef> corVarList = cm.mapRefRelToCorRef.get(rel);
+
+ // Try to populate correlation variables using local fields.
+ // This means that we do not need a value generator.
+ if (rel instanceof Filter) {
+ SortedMap<CorDef, Integer> map = new TreeMap<>(); // local; shadows the field named map inside this block
+ List<RexNode> projects = new ArrayList<>(); // extra expressions to append when the equivalent is not a plain input ref
+ for (CorRef correlation : corVarList) {
+ final CorDef def = correlation.def();
+ if (corDefOutputs.containsKey(def) || map.containsKey(def)) {
+ continue;
+ }
+ try {
+ findCorrelationEquivalent(correlation, ((Filter) rel).getCondition()); // reports a match by throwing FoundOne
+ } catch (Util.FoundOne e) {
+ if (e.getNode() instanceof RexInputRef) {
+ map.put(def, ((RexInputRef) e.getNode()).getIndex());
+ } else {
+ map.put(def,
+ frame.r.getRowType().getFieldCount() + projects.size());
+ projects.add((RexNode) e.getNode());
+ }
+ }
+ }
+ // If all correlation variables are now satisfied, skip creating a value
+ // generator.
+ if (map.size() == corVarList.size()) {
+ map.putAll(frame.corDefOutputs);
+ final RelNode r;
+ if (!projects.isEmpty()) {
+ relBuilder.push(oldInput)
+ .project(Iterables.concat(relBuilder.fields(), projects));
+ r = relBuilder.build();
+ } else {
+ r = oldInput;
+ }
+ return register(rel.getInput(0), r,
+ frame.oldToNewOutputs, map);
+ }
+ }
+
+ int leftInputOutputCount = frame.r.getRowType().getFieldCount();
+
+ // can directly add positions into corDefOutputs since join
+ // does not change the output ordering from the inputs.
+ RelNode valueGen =
+ createValueGenerator(corVarList, leftInputOutputCount, corDefOutputs);
+
+ RelNode join =
+ LogicalJoin.create(frame.r, valueGen, relBuilder.literal(true),
+ ImmutableSet.of(), JoinRelType.INNER);
+
+ // Join or Filter does not change the old input ordering. All
+ // input fields from newLeftInput (i.e. the original input to the old
+ // Filter) are in the output and in the same position.
+ return register(rel.getInput(0), join, frame.oldToNewOutputs,
+ corDefOutputs);
+ }
+
+ /** Searches {@code e} for an {@code EQUALS} call one of whose operands
+  * references {@code correlation}; when one is found, the opposite operand
+  * is reported by throwing {@link org.apache.calcite.util.Util.FoundOne}.
+  * {@code AND} calls are searched recursively; every other kind of
+  * expression is ignored. */
+ private void findCorrelationEquivalent(CorRef correlation, RexNode e)
+     throws Util.FoundOne {
+   switch (e.getKind()) {
+   case EQUALS:
+     final List<RexNode> sides = ((RexCall) e).getOperands();
+     if (references(sides.get(0), correlation)) {
+       throw new Util.FoundOne(sides.get(1));
+     }
+     if (references(sides.get(1), correlation)) {
+       throw new Util.FoundOne(sides.get(0));
+     }
+     break;
+   case AND:
+     for (RexNode conjunct : ((RexCall) e).getOperands()) {
+       findCorrelationEquivalent(correlation, conjunct);
+     }
+     break;
+   default:
+     break;
+   }
+ }
+
+ /**
+  * Returns whether {@code e} is a reference to the correlated field
+  * described by {@code correlation}.
+  *
+  * <p>A {@code FIELD_ACCESS} matches when its field index equals
+  * {@code correlation.field} and it is applied to a
+  * {@link RexCorrelVariable} whose id equals {@code correlation.corr}. A
+  * {@code CAST} matches when it is a widening cast (see
+  * {@code isWidening}) of a matching operand. Everything else does not
+  * match.
+  *
+  * @param e           expression to inspect
+  * @param correlation correlated field being looked for
+  */
+ private boolean references(RexNode e, CorRef correlation) {
+   switch (e.getKind()) {
+   case CAST:
+     final RexNode operand = ((RexCall) e).getOperands().get(0);
+     if (isWidening(e.getType(), operand.getType())) {
+       return references(operand, correlation);
+     }
+     return false;
+   case FIELD_ACCESS:
+     final RexFieldAccess f = (RexFieldAccess) e;
+     if (f.getField().getIndex() != correlation.field
+         || !(f.getReferenceExpr() instanceof RexCorrelVariable)) {
+       return false;
+     }
+     // Compare ids by value, as the other CorrelationId comparisons in this
+     // class do; reference equality would miss equal but distinct instances.
+     final CorrelationId id = ((RexCorrelVariable) f.getReferenceExpr()).id;
+     return id.equals(correlation.corr);
+   default:
+     return false;
+   }
+ }
+
+ /** Returns whether {@code type} is just a widening of {@code type1}: both
+  * have the same SQL type name and the precision of {@code type} is at
+  * least that of {@code type1}. Nullability is not compared.
+  *
+  * <p>For example:<ul>
+  * <li>{@code VARCHAR(10)} is a widening of {@code VARCHAR(5)}.
+  * <li>{@code VARCHAR(10)} is a widening of {@code VARCHAR(10) NOT NULL}.
+  * </ul>
+  */
+ private boolean isWidening(RelDataType type, RelDataType type1) {
+   if (type.getSqlTypeName() != type1.getSqlTypeName()) {
+     return false;
+   }
+   return type.getPrecision() >= type1.getPrecision();
+ }
+
+ /**
+  * Rewrites a {@link LogicalSnapshot}. A snapshot whose period contains a
+  * correlation is not rewritten; any other snapshot is handled by the
+  * generic {@code decorrelateRel(RelNode)} fallback.
+  *
+  * @param rel the snapshot rel to rewrite
+  * @return the fallback's frame, or null when the period is correlated
+  */
+ public Frame decorrelateRel(LogicalSnapshot rel) {
+   return RexUtil.containsCorrelation(rel.getPeriod())
+       ? null
+       : decorrelateRel((RelNode) rel);
+ }
+
+ /**
+  * Rewrites a {@link LogicalFilter}.
+  *
+  * <p>Rewrite logic:
+  * <ol>
+  * <li>If the Filter references a correlated field in its condition, the
+  * input is first extended so that it produces the correlated variables
+  * (see {@code maybeAddValueGenerator}), and the correlated field accesses
+  * in the condition are rewritten to reference that new input.
+  * <li>If the Filter does not reference correlated variables, the condition
+  * is simply rewritten against the new input.
+  * </ol>
+  *
+  * @param rel the filter rel to rewrite
+  * @return the registered frame, or null if the input was not rewritten
+  */
+ public Frame decorrelateRel(LogicalFilter rel) {
+   final RelNode oldInput = rel.getInput();
+   Frame frame = getInvoke(oldInput, rel);
+   if (frame == null) {
+     // If input has not been rewritten, do not rewrite this rel.
+     return null;
+   }
+
+   // If this Filter has correlated references that the input does not yet
+   // provide, make the input produce them.
+   frame = maybeAddValueGenerator(rel, frame);
+
+   final CorelMap cm2 = new CorelMapBuilder().build(rel);
+
+   // Replace the filter expression to reference the output of the new input,
+   // and map the old filter to the new filter.
+   relBuilder.push(frame.r)
+       .filter(decorrelateExpr(currentRel, map, cm2, rel.getCondition()));
+
+   // Filter does not change or permute the input ordering, so all corVars
+   // keep the output positions they have in the input rel.
+   return register(rel, relBuilder.build(), frame.oldToNewOutputs,
+       frame.corDefOutputs);
+ }
+
+ /**
+ * Rewrite Correlate into a join whose type is derived from the Correlate's
+ * join type via {@code toJoinRelType}.
+ *
+ * <p>The join condition is a conjunction of equalities, one for each
+ * correlated variable produced by the right input that belongs to this
+ * Correlate's correlation id. Correlated variables belonging to other ids
+ * are passed along in the returned frame.
+ *
+ * @param rel Correlator
+ * @return the registered frame, or null if either input was not rewritten or
+ * the right input produces no correlated variables
+ */
+ public Frame decorrelateRel(LogicalCorrelate rel) {
+ //
+ // Rewrite logic:
+ //
+ // The original left input will be joined with the new right input that
+ // has generated correlated variables propagated up. For any generated
+ // corVars that are not used in the join key, pass them along to be
+ // joined later with the Correlates that produce them.
+ //
+
+ // the right input to Correlate should produce correlated variables
+ final RelNode oldLeft = rel.getInput(0);
+ final RelNode oldRight = rel.getInput(1);
+
+ final Frame leftFrame = getInvoke(oldLeft, rel);
+ final Frame rightFrame = getInvoke(oldRight, rel);
+
+ if (leftFrame == null || rightFrame == null) {
+ // If any input has not been rewritten, do not rewrite this rel.
+ return null;
+ }
+
+ if (rightFrame.corDefOutputs.isEmpty()) {
+ return null;
+ }
+
+ assert rel.getRequiredColumns().cardinality()
+ <= rightFrame.corDefOutputs.keySet().size();
+
+ // Change correlator rel into a join.
+ // Join all the correlated variables produced by this correlator rel
+ // with the values generated and propagated from the right input
+ final SortedMap<CorDef, Integer> corDefOutputs =
+ new TreeMap<>(rightFrame.corDefOutputs);
+ final List<RexNode> conditions = new ArrayList<>();
+ final List<RelDataTypeField> newLeftOutput =
+ leftFrame.r.getRowType().getFieldList();
+ int newLeftFieldCount = newLeftOutput.size();
+
+ final List<RelDataTypeField> newRightOutput =
+ rightFrame.r.getRowType().getFieldList();
+
+ for (Map.Entry<CorDef, Integer> rightOutput
+ : new ArrayList<>(corDefOutputs.entrySet())) { // iterate over a copy because entries are removed below
+ final CorDef corDef = rightOutput.getKey();
+ if (!corDef.corr.equals(rel.getCorrelationId())) {
+ continue;
+ }
+ final int newLeftPos = leftFrame.oldToNewOutputs.get(corDef.field);
+ final int newRightPos = rightOutput.getValue();
+ conditions.add(
+ relBuilder.call(SqlStdOperatorTable.EQUALS,
+ RexInputRef.of(newLeftPos, newLeftOutput),
+ new RexInputRef(newLeftFieldCount + newRightPos, // right-side fields follow the left-side fields in the join output
+ newRightOutput.get(newRightPos).getType())));
+
+ // remove this corVar from output position mapping
+ corDefOutputs.remove(corDef);
+ }
+
+ // Update the output position for the corVars: only pass on the cor
+ // vars that are not used in the join key.
+ for (CorDef corDef : corDefOutputs.keySet()) {
+ int newPos = corDefOutputs.get(corDef) + newLeftFieldCount;
+ corDefOutputs.put(corDef, newPos); // overwrites the value of an existing key only
+ }
+
+ // then add any corVar from the left input. Do not need to change
+ // output positions.
+ corDefOutputs.putAll(leftFrame.corDefOutputs);
+
+ // Create the mapping between the output of the old correlation rel
+ // and the new join rel
+ final Map<Integer, Integer> mapOldToNewOutputs = new HashMap<>();
+
+ int oldLeftFieldCount = oldLeft.getRowType().getFieldCount();
+
+ int oldRightFieldCount = oldRight.getRowType().getFieldCount();
+ //noinspection AssertWithSideEffects
+ assert rel.getRowType().getFieldCount()
+ == oldLeftFieldCount + oldRightFieldCount;
+
+ // Left input positions are not changed.
+ mapOldToNewOutputs.putAll(leftFrame.oldToNewOutputs);
+
+ // Right input positions are shifted by newLeftFieldCount.
+ for (int i = 0; i < oldRightFieldCount; i++) {
+ mapOldToNewOutputs.put(i + oldLeftFieldCount,
+ rightFrame.oldToNewOutputs.get(i) + newLeftFieldCount);
+ }
+
+ final RexNode condition =
+ RexUtil.composeConjunction(relBuilder.getRexBuilder(), conditions);
+ RelNode newJoin =
+ LogicalJoin.create(leftFrame.r, rightFrame.r, condition,
+ ImmutableSet.of(), toJoinRelType(rel.getJoinType()));
+
+ return register(rel, newJoin, mapOldToNewOutputs, corDefOutputs);
+ }
+
+ /**
+  * Rewrites a {@link LogicalJoin}: the join condition is rewritten against
+  * the decorrelated inputs, output positions are re-mapped, and the
+  * correlated variables of both sides are passed along.
+  *
+  * <p>A join whose type does not project its right side is delegated to the
+  * generic {@code decorrelateRel(RelNode)} fallback.
+  *
+  * @param rel Join
+  * @return the registered frame, or null if either input was not rewritten
+  */
+ public Frame decorrelateRel(LogicalJoin rel) {
+   if (!rel.getJoinType().projectsRight()) {
+     return decorrelateRel((RelNode) rel);
+   }
+
+   final RelNode oldLeft = rel.getInput(0);
+   final RelNode oldRight = rel.getInput(1);
+
+   final Frame leftFrame = getInvoke(oldLeft, rel);
+   final Frame rightFrame = getInvoke(oldRight, rel);
+   if (leftFrame == null || rightFrame == null) {
+     // If any input has not been rewritten, do not rewrite this rel.
+     return null;
+   }
+
+   final RexNode newCondition =
+       decorrelateExpr(currentRel, map, cm, rel.getCondition());
+   final RelNode newJoin =
+       LogicalJoin.create(leftFrame.r, rightFrame.r, newCondition,
+           ImmutableSet.of(), rel.getJoinType());
+
+   final int oldLeftFieldCount = oldLeft.getRowType().getFieldCount();
+   final int newLeftFieldCount = leftFrame.r.getRowType().getFieldCount();
+   final int oldRightFieldCount = oldRight.getRowType().getFieldCount();
+   //noinspection AssertWithSideEffects
+   assert rel.getRowType().getFieldCount()
+       == oldLeftFieldCount + oldRightFieldCount;
+
+   // Output mapping: left positions are taken over unchanged, right
+   // positions are shifted by newLeftFieldCount.
+   final Map<Integer, Integer> oldToNew =
+       new HashMap<>(leftFrame.oldToNewOutputs);
+   for (int i = 0; i < oldRightFieldCount; i++) {
+     oldToNew.put(oldLeftFieldCount + i,
+         newLeftFieldCount + rightFrame.oldToNewOutputs.get(i));
+   }
+
+   // Correlated variables: same treatment as the output mapping.
+   final SortedMap<CorDef, Integer> corDefOutputs =
+       new TreeMap<>(leftFrame.corDefOutputs);
+   rightFrame.corDefOutputs.forEach((corDef, rightPos) ->
+       corDefOutputs.put(corDef, rightPos + newLeftFieldCount));
+
+   return register(rel, newJoin, oldToNew, corDefOutputs);
+ }
+
+ private static RexInputRef getNewForOldInputRef(RelNode currentRel,
+ Map<RelNode, Frame> map, RexInputRef oldInputRef) {
+ assert currentRel != null;
+
+ int oldOrdinal = oldInputRef.getIndex();
+ int newOrdinal = 0;
+
+ // determine which input rel oldOrdinal references, and adjust
+ // oldOrdinal to be relative to that input rel
+ RelNode oldInput = null;
+
+ for (RelNode oldInput0 : currentRel.getInputs()) {
+ RelDataType oldInputType = oldInput0.getRowType();
+ int n = oldInputType.getFieldCount();
+ if (oldOrdinal < n) {
+ oldInput = oldInput0;
+ break;
+ }
+ RelNode newInput = map.get(oldInput0).r;
+ newOrdinal += newInput.getRowType().getFieldCount();
+ oldOrdinal -= n;
+ }
+
+ assert oldInput != null;
+
+ final Frame frame = map.get(oldInput);
+ assert frame != null;
+
+ // now oldOrdinal is relative to oldInput
+ int oldLocalOrdinal = oldOrdinal;
+
+ // figure out the newLocalOrdinal, relative to the newInput.
+ int newLocalOrdinal = oldLocalOrdinal;
+
+ if (!frame.oldToNewOutputs.isEmpty()) {
+ newLocalOrdinal = frame.oldToNewOutputs.get(oldLocalOrdinal);
+ }
+
+ newOrdinal += newLocalOrdinal;
+
+ return new RexInputRef(newOrdinal,
+ frame.r.getRowType().getFieldList().get(newLocalOrdinal).getType());
+ }
+
+ /**
+  * Builds a Project on top of {@code join} that emits every field of the
+  * join's left input followed by the expressions of {@code project},
+  * rewritten by {@code removeCorrelationExpr} with a nullable null
+  * indicator.
+  *
+  * @param join Join
+  * @param project Original project as the right-hand input of the join
+  * @param nullIndicatorPos Position of null indicator
+  * @return the subtree with the new Project at the root
+  */
+ private RelNode projectJoinOutputWithNullability(
+     LogicalJoin join,
+     LogicalProject project,
+     int nullIndicatorPos) {
+   final RelDataTypeFactory typeFactory = join.getCluster().getTypeFactory();
+
+   // The null indicator is typed as the nullable variant of the join field.
+   final RelDataType indicatorType =
+       typeFactory.createTypeWithNullability(
+           join.getRowType().getFieldList().get(nullIndicatorPos).getType(),
+           true);
+   final RexInputRef nullIndicator =
+       new RexInputRef(nullIndicatorPos, indicatorType);
+
+   final List<Pair<RexNode, String>> newProjExprs = new ArrayList<>();
+
+   // Everything from the LHS comes first ...
+   final List<RelDataTypeField> leftFields =
+       join.getLeft().getRowType().getFieldList();
+   for (int i = 0; i < leftFields.size(); i++) {
+     newProjExprs.add(RexInputRef.of2(i, leftFields));
+   }
+
+   // ... then the original projections. The flag tells the rewrite whether
+   // they now come out of the null-generating side of the join.
+   final boolean rightGeneratesNulls =
+       join.getJoinType().generatesNullsOnRight();
+   for (Pair<RexNode, String> namedProject : project.getNamedProjects()) {
+     final RexNode rewritten =
+         removeCorrelationExpr(namedProject.left, rightGeneratesNulls,
+             nullIndicator);
+     newProjExprs.add(Pair.of(rewritten, namedProject.right));
+   }
+
+   return relBuilder.push(join)
+       .projectNamed(Pair.left(newProjExprs), Pair.right(newProjExprs), true)
+       .build();
+ }
+
+ private JoinRelType toJoinRelType(SemiJoinType semiJoinType) {
+ switch (semiJoinType) {
+ case INNER:
+ return JoinRelType.INNER;
+ case LEFT:
+ return JoinRelType.LEFT;
+ case SEMI:
+ return JoinRelType.SEMI;
+ case ANTI:
+ return JoinRelType.ANTI;
+ default:
+ throw new IllegalArgumentException("Unsupported type: " + semiJoinType);
+ }
+ }
+
+ /**
+ * Builds a {@link Project} on top of {@code correlate} that forwards every
+ * field of the correlate's left input and then re-creates the expressions of
+ * {@code project}, with their correlation expressions removed. Nullability
+ * is enforced for the re-created expressions when the correlate's join type
+ * generates nulls on the right side.
+ *
+ * @param correlate Correlate
+ * @param project the original project as the RHS input of the join
+ * @param isCount Positions which are calls to the <code>COUNT</code>
+ * aggregation function
+ * @return the subtree with the new Project at the root
+ */
+ private RelNode aggregateCorrelatorOutput(
+ Correlate correlate,
+ LogicalProject project,
+ Set<Integer> isCount) {
+ // Start with an identity projection of the left input's fields.
+ final List<RelDataTypeField> lhsFields =
+ correlate.getLeft().getRowType().getFieldList();
+ final List<Pair<RexNode, String>> namedExprs = new ArrayList<>();
+ for (int field = 0; field < lhsFields.size(); field++) {
+ namedExprs.add(RexInputRef.of2(field, lhsFields));
+ }
+
+ // The original projections now sit above the correlate; if its join type
+ // generates nulls on the right side, their types have to become nullable.
+ final boolean nullsOnRight =
+ toJoinRelType(correlate.getJoinType()).generatesNullsOnRight();
+ for (Pair<RexNode, String> namedExpr : project.getNamedProjects()) {
+ final RexNode rewritten =
+ removeCorrelationExpr(namedExpr.left, nullsOnRight, isCount);
+ namedExprs.add(Pair.of(rewritten, namedExpr.right));
+ }
+
+ return relBuilder.push(correlate)
+ .projectNamed(Pair.left(namedExprs), Pair.right(namedExprs), true)
+ .build();
+ }
+
+ /**
+ * Checks whether the correlations in projRel and filter are related to
+ * the correlated variables provided by corRel.
+ *
+ * @param correlate Correlate
+ * @param project The original Project as the RHS input of the join
+ * @param filter Filter
+ * @param correlatedJoinKeys Correlated join keys
+ * @return true if filter and proj only references corVar provided by corRel
+ */
+ private boolean checkCorVars(
+ LogicalCorrelate correlate,
+ LogicalProject project,
+ LogicalFilter filter,
+ List<RexFieldAccess> correlatedJoinKeys) {
+ if (filter != null) {
+ assert correlatedJoinKeys != null;
+
+ // check that all correlated refs in the filter condition are
+ // used in the join(as field access).
+ Set<CorRef> corVarInFilter =
+ Sets.newHashSet(cm.mapRefRelToCorRef.get(filter));
+
+ for (RexFieldAccess correlatedJoinKey : correlatedJoinKeys) {
+ corVarInFilter.remove(cm.mapFieldAccessToCorRef.get(correlatedJoinKey));
+ }
+
+ if (!corVarInFilter.isEmpty()) {
+ return false;
+ }
+
+ // Check that the correlated variables referenced in these
+ // comparisons do come from the Correlate.
+ corVarInFilter.addAll(cm.mapRefRelToCorRef.get(filter));
+
+ for (CorRef corVar : corVarInFilter) {
+ if (cm.mapCorToCorRel.get(corVar.corr) != correlate) {
+ return false;
+ }
+ }
+ }
+
+ // if project has any correlated reference, make sure they are also
+ // provided by the current correlate. They will be projected out of the LHS
+ // of the correlate.
+ if ((project != null) && cm.mapRefRelToCorRef.containsKey(project)) {
+ for (CorRef corVar : cm.mapRefRelToCorRef.get(project)) {
+ if (cm.mapCorToCorRel.get(corVar.corr) != correlate) {
+ return false;
+ }
+ }
+ }
+
+ return true;
+ }
+
+ /**
+ * Drops the correlation id of {@code correlate} from the map of correlation
+ * ids to correlates, provided the id is currently mapped to this very
+ * correlate.
+ *
+ * @param correlate Correlate
+ */
+ private void removeCorVarFromTree(LogicalCorrelate correlate) {
+ if (cm.mapCorToCorRel.get(correlate.getCorrelationId()) != correlate) {
+ // The id is mapped to another rel (or not at all); leave the map alone.
+ return;
+ }
+ cm.mapCorToCorRel.remove(correlate.getCorrelationId());
+ }
+
+ /**
+ * Creates a Project over {@code input} that forwards all of its fields under
+ * their existing names and appends the given named expressions.
+ *
+ * @param input Input relational expression
+ * @param additionalExprs Additional expressions and names
+ * @return the new Project
+ */
+ private RelNode createProjectWithAdditionalExprs(
+ RelNode input,
+ List<Pair<RexNode, String>> additionalExprs) {
+ final List<RelDataTypeField> inputFields =
+ input.getRowType().getFieldList();
+ final List<Pair<RexNode, String>> namedExprs = new ArrayList<>();
+
+ // Identity projection of every input field, keeping its name.
+ for (int ordinal = 0; ordinal < inputFields.size(); ordinal++) {
+ final RelDataTypeField inputField = inputFields.get(ordinal);
+ final RexNode fieldRef =
+ relBuilder.getRexBuilder().makeInputRef(inputField.getType(), ordinal);
+ namedExprs.add(Pair.of(fieldRef, inputField.getName()));
+ }
+ namedExprs.addAll(additionalExprs);
+
+ return relBuilder.push(input)
+ .projectNamed(Pair.left(namedExprs), Pair.right(namedExprs), true)
+ .build();
+ }
+
+ /* Returns an immutable map with the identity [0: 0, .., count-1: count-1]. */
+ static Map<Integer, Integer> identityMap(int count) {
+ ImmutableMap.Builder<Integer, Integer> builder = ImmutableMap.builder();
+ for (int i = 0; i < count; i++) {
+ builder.put(i, i);
+ }
+ return builder.build();
+ }
+
+ /**
+ * Records, under {@code rel}, a frame that pairs it with the relational
+ * expression it became after decorrelation.
+ *
+ * @param rel original relational expression
+ * @param newRel relational expression after decorrelation
+ * @param oldToNewOutputs mapping of old output positions to new ones
+ * @param corDefOutputs output positions of the correlation definitions
+ * @return the frame that was stored
+ */
+ Frame register(RelNode rel, RelNode newRel,
+ Map<Integer, Integer> oldToNewOutputs,
+ SortedMap<CorDef, Integer> corDefOutputs) {
+ final Frame registered =
+ new Frame(rel, newRel, corDefOutputs, oldToNewOutputs);
+ map.put(rel, registered);
+ return registered;
+ }
+
+ static boolean allLessThan(Collection<Integer> integers, int limit,
+ Litmus ret) {
+ for (int value : integers) {
+ if (value >= limit) {
+ return ret.fail("out of range; value: {}, limit: {}", value, limit);
+ }
+ }
+ return ret.succeed();
+ }
+
+ /**
+ * Unwraps a {@link HepRelVertex}.
+ *
+ * @param rel relational expression, possibly a vertex
+ * @return the vertex's current rel if {@code rel} is a vertex, otherwise
+ * {@code rel} itself
+ */
+ private static RelNode stripHep(RelNode rel) {
+ return rel instanceof HepRelVertex
+ ? ((HepRelVertex) rel).getCurrentRel()
+ : rel;
+ }
+
+ //~ Inner Classes ----------------------------------------------------------
+
+ /**
+ * Shuttle that decorrelates: it rewrites the expressions of a relational
+ * expression so that they address its rewritten inputs. Correlated field
+ * accesses are replaced by input refs when one of the inputs produces the
+ * referenced correlation variable, and input refs are re-mapped to their
+ * positions in the rewritten inputs.
+ */
+ private static class DecorrelateRexShuttle extends RexShuttle {
+ private final RelNode currentRel;
+ private final Map<RelNode, Frame> map;
+ private final CorelMap cm;
+
+ private DecorrelateRexShuttle(RelNode currentRel,
+ Map<RelNode, Frame> map, CorelMap cm) {
+ this.currentRel = Objects.requireNonNull(currentRel);
+ this.map = Objects.requireNonNull(map);
+ this.cm = Objects.requireNonNull(cm);
+ }
+
+ @Override public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
+ // Number of output fields of the (rewritten) inputs visited so far.
+ int precedingFieldCount = 0;
+ for (RelNode input : currentRel.getInputs()) {
+ final Frame inputFrame = map.get(input);
+ if (inputFrame == null) {
+ // this input is not rewritten
+ precedingFieldCount += input.getRowType().getFieldCount();
+ continue;
+ }
+
+ // try to find in this input rel the position of corVar
+ final CorRef corRef = cm.mapFieldAccessToCorRef.get(fieldAccess);
+ final Integer corVarPos = corRef == null
+ ? null
+ : inputFrame.corDefOutputs.get(corRef.def());
+ if (corVarPos != null) {
+ // This input does produce the corVar referenced.
+ return new RexInputRef(corVarPos + precedingFieldCount,
+ inputFrame.r.getRowType().getFieldList().get(corVarPos).getType());
+ }
+
+ // this input does not produce the corVar needed
+ precedingFieldCount += inputFrame.r.getRowType().getFieldCount();
+ }
+ return fieldAccess;
+ }
+
+ @Override public RexNode visitInputRef(RexInputRef inputRef) {
+ final RexInputRef mapped =
+ getNewForOldInputRef(currentRel, map, inputRef);
+ final boolean unchanged = mapped.getIndex() == inputRef.getIndex()
+ && mapped.getType() == inputRef.getType();
+ // re-use old object, to prevent needless expr cloning
+ return unchanged ? inputRef : mapped;
+ }
+ }
+
+ /**
+ * Shuttle that removes correlations.
+ *
+ * <p>Correlated field accesses are replaced by input refs. When
+ * {@code projectPulledAboveLeftCorrelator} is set and a null indicator is
+ * supplied, rewritten field accesses, non-null literals and calls are
+ * wrapped in a CASE expression that yields NULL whenever the null indicator
+ * is NULL.
+ */
+ private class RemoveCorrelationRexShuttle extends RexShuttle {
+ final RexBuilder rexBuilder;
+ final RelDataTypeFactory typeFactory;
+ // Whether rewritten expressions have to be made nullable.
+ final boolean projectPulledAboveLeftCorrelator;
+ // Reference whose NULL-ness decides whether NULL is projected; may be null.
+ final RexInputRef nullIndicator;
+ // Input positions that get the zero-instead-of-null treatment in
+ // visitInputRef.
+ final ImmutableSet<Integer> isCount;
+
+ /**
+ * Creates a shuttle.
+ *
+ * @param rexBuilder builder for the rewritten expressions; also supplies
+ * the type factory
+ * @param projectPulledAboveLeftCorrelator whether rewritten expressions
+ * have to be made nullable
+ * @param nullIndicator null indicator; may be null
+ * @param isCount positions to wrap in {@code visitInputRef}; copied into
+ * an immutable set
+ */
+ RemoveCorrelationRexShuttle(
+ RexBuilder rexBuilder,
+ boolean projectPulledAboveLeftCorrelator,
+ RexInputRef nullIndicator,
+ Set<Integer> isCount) {
+ this.projectPulledAboveLeftCorrelator =
+ projectPulledAboveLeftCorrelator;
+ this.nullIndicator = nullIndicator; // may be null
+ this.isCount = ImmutableSet.copyOf(isCount);
+ this.rexBuilder = rexBuilder;
+ this.typeFactory = rexBuilder.getTypeFactory();
+ }
+
+ /**
+ * Builds
+ * {@code CASE WHEN nullInputRef IS NULL THEN CAST(lit) ELSE CAST(rexNode) END},
+ * where both casts target the nullable variant of {@code rexNode}'s type.
+ *
+ * @param nullInputRef reference tested for NULL; it is re-created with a
+ * nullable type
+ * @param lit literal produced when the reference is NULL
+ * @param rexNode expression produced otherwise
+ * @return the CASE call
+ */
+ private RexNode createCaseExpression(
+ RexInputRef nullInputRef,
+ RexLiteral lit,
+ RexNode rexNode) {
+ RexNode[] caseOperands = new RexNode[3];
+
+ // Construct a CASE expression to handle the null indicator.
+ //
+ // This also covers the case where a left correlated sub-query
+ // projects fields from outer relation. Since LOJ cannot produce
+ // nulls on the LHS, the projection now need to make a nullable LHS
+ // reference using a nullability indicator. If this indicator
+ // is null, it means the sub-query does not produce any value. As a
+ // result, any RHS ref by this sub-query needs to produce null value.
+
+ // WHEN indicator IS NULL
+ caseOperands[0] =
+ rexBuilder.makeCall(
+ SqlStdOperatorTable.IS_NULL,
+ new RexInputRef(
+ nullInputRef.getIndex(),
+ typeFactory.createTypeWithNullability(
+ nullInputRef.getType(),
+ true)));
+
+ // THEN CAST(NULL AS newInputTypeNullable)
+ caseOperands[1] =
+ rexBuilder.makeCast(
+ typeFactory.createTypeWithNullability(
+ rexNode.getType(),
+ true),
+ lit);
+
+ // ELSE cast (newInput AS newInputTypeNullable) END
+ caseOperands[2] =
+ rexBuilder.makeCast(
+ typeFactory.createTypeWithNullability(
+ rexNode.getType(),
+ true),
+ rexNode);
+
+ return rexBuilder.makeCall(
+ SqlStdOperatorTable.CASE,
+ caseOperands);
+ }
+
+ /**
+ * Replaces a field access that is registered as a correlated reference by
+ * an input ref on the referenced field; other field accesses are returned
+ * unchanged.
+ */
+ @Override public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
+ if (cm.mapFieldAccessToCorRef.containsKey(fieldAccess)) {
+ // if it is a corVar, change it to be input ref.
+ CorRef corVar = cm.mapFieldAccessToCorRef.get(fieldAccess);
+
+ // corVar offset should point to the leftInput of currentRel,
+ // which is the Correlate.
+ RexNode newRexNode =
+ new RexInputRef(corVar.field, fieldAccess.getType());
+
+ if (projectPulledAboveLeftCorrelator
+ && (nullIndicator != null)) {
+ // need to enforce nullability by applying an additional
+ // cast operator over the transformed expression.
+ newRexNode =
+ createCaseExpression(
+ nullIndicator,
+ rexBuilder.constantNull(),
+ newRexNode);
+ }
+ return newRexNode;
+ }
+ return fieldAccess;
+ }
+
+ /**
+ * When the current rel is a {@link LogicalCorrelate}, shifts the input ref
+ * by the field count of the correlate's left input; otherwise returns the
+ * input ref unchanged.
+ */
+ @Override public RexNode visitInputRef(RexInputRef inputRef) {
+ if (currentRel instanceof LogicalCorrelate) {
+ // if this rel references corVar
+ // and now it needs to be rewritten
+ // it must have been pulled above the Correlate
+ // replace the input ref to account for the LHS of the
+ // Correlate
+ final int leftInputFieldCount =
+ ((LogicalCorrelate) currentRel).getLeft().getRowType()
+ .getFieldCount();
+ RelDataType newType = inputRef.getType();
+
+ if (projectPulledAboveLeftCorrelator) {
+ newType =
+ typeFactory.createTypeWithNullability(newType, true);
+ }
+
+ int pos = inputRef.getIndex();
+ RexInputRef newInputRef =
+ new RexInputRef(leftInputFieldCount + pos, newType);
+
+ // Positions listed in isCount produce 0 rather than NULL when the
+ // shifted reference is NULL.
+ if ((isCount != null) && isCount.contains(pos)) {
+ return createCaseExpression(
+ newInputRef,
+ rexBuilder.makeExactLiteral(BigDecimal.ZERO),
+ newInputRef);
+ } else {
+ return newInputRef;
+ }
+ }
+ return inputRef;
+ }
+
+ /**
+ * Wraps a non-null literal in a CASE on the null indicator when
+ * nullability has to be enforced; otherwise returns the literal unchanged.
+ */
+ @Override public RexNode visitLiteral(RexLiteral literal) {
+ // Use nullIndicator to decide whether to project null.
+ // Do nothing if the literal is null.
+ if (!RexUtil.isNull(literal)
+ && projectPulledAboveLeftCorrelator
+ && (nullIndicator != null)) {
+ return createCaseExpression(
+ nullIndicator,
+ rexBuilder.constantNull(),
+ literal);
+ }
+ return literal;
+ }
+
+ /**
+ * Visits the operands of a call, re-creates the call if any operand
+ * changed, and wraps the result in a CASE on the null indicator when
+ * nullability has to be enforced.
+ */
+ @Override public RexNode visitCall(final RexCall call) {
+ RexNode newCall;
+
+ // update[0] is set by visitList when at least one operand was rewritten.
+ boolean[] update = {false};
+ List<RexNode> clonedOperands = visitList(call.operands, update);
+ if (update[0]) {
+ SqlOperator operator = call.getOperator();
+
+ // A CAST function call with fewer than two operands keeps its current
+ // return type instead of having it re-derived.
+ boolean isSpecialCast = false;
+ if (operator instanceof SqlFunction) {
+ SqlFunction function = (SqlFunction) operator;
+ if (function.getKind() == SqlKind.CAST) {
+ if (call.operands.size() < 2) {
+ isSpecialCast = true;
+ }
+ }
+ }
+
+ final RelDataType newType;
+ if (!isSpecialCast) {
+ // TODO: ideally this only needs to be called if the result
+ // type will also change. However, since that requires
+ // support from type inference rules to tell whether a rule
+ // decides return type based on input types, for now all
+ // operators will be recreated with new type if any operand
+ // changed, unless the operator has "built-in" type.
+ newType = rexBuilder.deriveReturnType(operator, clonedOperands);
+ } else {
+ // Use the current return type when creating a new call, for
+ // operators with return type built into the operator
+ // definition, and with no type inference rules, such as
+ // cast function with less than 2 operands.
+
+ // TODO: Comments in RexShuttle.visitCall() mention other
+ // types in this category. Need to resolve those together
+ // and preferably in the base class RexShuttle.
+ newType = call.getType();
+ }
+ newCall =
+ rexBuilder.makeCall(
+ newType,
+ operator,
+ clonedOperands);
+ } else {
+ newCall = call;
+ }
+
+ // The wrapping is applied whether or not the call itself was re-created.
+ if (projectPulledAboveLeftCorrelator && (nullIndicator != null)) {
+ return createCaseExpression(
+ nullIndicator,
+ rexBuilder.constantNull(),
+ newCall);
+ }
+ return newCall;
+ }
+ }
+
+ /**
+ * Rule that removes a single_value aggregate sitting on top of a scalar
+ * sub-query.
+ *
+ * <p>It matches an Aggregate over a Project over an Aggregate and fires when
+ * the upper aggregate has no group keys and consists of exactly one
+ * single_value call, the project has exactly one expression, and the lower
+ * aggregate has no group keys. The three nodes are replaced by a Project
+ * over the lower aggregate that casts the projected expression to a
+ * nullable type.
+ */
+ private final class RemoveSingleAggregateRule extends RelOptRule {
+ RemoveSingleAggregateRule(RelBuilderFactory relBuilderFactory) {
+ super(
+ operand(
+ LogicalAggregate.class,
+ operand(
+ LogicalProject.class,
+ operand(LogicalAggregate.class, any()))),
+ relBuilderFactory, null);
+ }
+
+ @Override public void onMatch(RelOptRuleCall call) {
+ final LogicalAggregate singleValueAgg = call.rel(0);
+ final LogicalProject project = call.rel(1);
+ final LogicalAggregate innerAgg = call.rel(2);
+
+ // The upper aggregate must be a lone single_value call without group keys.
+ if (!singleValueAgg.getGroupSet().isEmpty()) {
+ return;
+ }
+ final List<AggregateCall> topCalls = singleValueAgg.getAggCallList();
+ if (topCalls.size() != 1) {
+ return;
+ }
+ if (!(topCalls.get(0).getAggregation()
+ instanceof SqlSingleValueAggFunction)) {
+ return;
+ }
+
+ // The project must project a single expression, i.e. a scalar sub-query.
+ final List<RexNode> projected = project.getProjects();
+ if (projected.size() != 1) {
+ return;
+ }
+
+ // The lower aggregate must aggregate over the entire input.
+ if (!innerAgg.getGroupSet().isEmpty()) {
+ return;
+ }
+
+ // The single_value aggregate produces a nullable type, so the replacement
+ // projection casts the projected expression to a nullable type.
+ final RelBuilder builder = call.builder();
+ final RexNode scalar = projected.get(0);
+ final RelDataType nullableType =
+ builder.getTypeFactory()
+ .createTypeWithNullability(scalar.getType(), true);
+ final RexNode nullableScalar =
+ builder.getRexBuilder().makeCast(nullableType, scalar);
+ call.transformTo(
+ builder.push(innerAgg)
+ .project(nullableScalar)
+ .build());
+ }
+ }
+
+ /**
+ * Planner rule that removes correlations for scalar projects.
+ *
+ * <p>It matches a LogicalCorrelate whose right input is a LogicalAggregate
+ * over a LogicalProject, and rewrites the correlate into a LogicalJoin with
+ * a Project on top. Two shapes are handled: the project's input is a
+ * correlated Filter (the filter condition becomes the join condition and the
+ * aggregate is dropped), or the project itself references the correlation
+ * variable (a "nullIndicator" column and a single-value aggregate are added
+ * on the right side). In every other case the rule does nothing.
+ */
+ private final class RemoveCorrelationForScalarProjectRule extends RelOptRule {
+ RemoveCorrelationForScalarProjectRule(RelBuilderFactory relBuilderFactory) {
+ super(
+ operand(LogicalCorrelate.class,
+ operand(RelNode.class, any()),
+ operand(LogicalAggregate.class,
+ operand(LogicalProject.class,
+ operand(RelNode.class, any())))),
+ relBuilderFactory, null);
+ }
+
+ /**
+ * Validates the matched pattern and, if every check passes, replaces the
+ * correlate by a project over a join and removes the correlate's
+ * correlation variable from the correlation map.
+ */
+ public void onMatch(RelOptRuleCall call) {
+ final LogicalCorrelate correlate = call.rel(0);
+ final RelNode left = call.rel(1);
+ final LogicalAggregate aggregate = call.rel(2);
+ final LogicalProject project = call.rel(3);
+ RelNode right = call.rel(4);
+ final RelOptCluster cluster = correlate.getCluster();
+
+ setCurrent(call.getPlanner().getRoot(), correlate);
+
+ // Check for this pattern.
+ // The pattern matching could be simplified if rules can be applied
+ // during decorrelation.
+ //
+ // Correlate(left correlation, condition = true)
+ // leftInput
+ // Aggregate (groupby (0) single_value())
+ // Project-A (may reference corVar)
+ // rightInput
+ final JoinRelType joinType = toJoinRelType(correlate.getJoinType());
+
+ // corRel.getCondition was here, however Correlate was updated so it
+ // never includes a join condition. The code was not modified for brevity.
+ // NOTE(review): the second test below compares by reference; it only
+ // behaves as "condition is true" if relBuilder.literal(true) returns the
+ // same instance on every call -- confirm.
+ RexNode joinCond = relBuilder.literal(true);
+ if ((joinType != JoinRelType.LEFT)
+ || (joinCond != relBuilder.literal(true))) {
+ return;
+ }
+
+ // check that the agg is of the following type:
+ // doing a single_value() on the entire input
+ if ((!aggregate.getGroupSet().isEmpty())
+ || (aggregate.getAggCallList().size() != 1)
+ || !(aggregate.getAggCallList().get(0).getAggregation()
+ instanceof SqlSingleValueAggFunction)) {
+ return;
+ }
+
+ // check this project only projects one expression, i.e. scalar
+ // sub-queries.
+ if (project.getProjects().size() != 1) {
+ return;
+ }
+
+ // Position, in the join output, of the field used as the null indicator;
+ // assigned in each branch that does not return.
+ int nullIndicatorPos;
+
+ if ((right instanceof LogicalFilter)
+ && cm.mapRefRelToCorRef.containsKey(right)) {
+ // rightInput has this shape:
+ //
+ // Filter (references corVar)
+ // filterInput
+
+ // If rightInput is a filter and contains correlated
+ // reference, make sure the correlated keys in the filter
+ // condition forms a unique key of the RHS.
+
+ LogicalFilter filter = (LogicalFilter) right;
+ right = filter.getInput();
+
+ // The filter input is unwrapped from its HepRelVertex before it is
+ // inspected and used as the join's right input.
+ assert right instanceof HepRelVertex;
+ right = ((HepRelVertex) right).getCurrentRel();
+
+ // check filter input contains no correlation
+ if (RelOptUtil.getVariablesUsed(right).size() > 0) {
+ return;
+ }
+
+ // extract the correlation out of the filter
+
+ // First breaking up the filter conditions into equality
+ // comparisons between rightJoinKeys (from the original
+ // filterInput) and correlatedJoinKeys. correlatedJoinKeys
+ // can be expressions, while rightJoinKeys need to be input
+ // refs. These comparisons are AND'ed together.
+ List<RexNode> tmpRightJoinKeys = new ArrayList<>();
+ List<RexNode> correlatedJoinKeys = new ArrayList<>();
+ RelOptUtil.splitCorrelatedFilterCondition(
+ filter,
+ tmpRightJoinKeys,
+ correlatedJoinKeys,
+ false);
+
+ // check that the columns referenced in these comparisons form
+ // an unique key of the filterInput
+ final List<RexInputRef> rightJoinKeys = new ArrayList<>();
+ for (RexNode key : tmpRightJoinKeys) {
+ assert key instanceof RexInputRef;
+ rightJoinKeys.add((RexInputRef) key);
+ }
+
+ // check that the columns referenced in rightJoinKeys form an
+ // unique key of the filterInput
+ if (rightJoinKeys.isEmpty()) {
+ return;
+ }
+
+ // The join filters out the nulls. So, it's ok if there are
+ // nulls in the join keys.
+ final RelMetadataQuery mq = call.getMetadataQuery();
+ if (!RelMdUtil.areColumnsDefinitelyUniqueWhenNullsFiltered(mq, right,
+ rightJoinKeys)) {
+ // NOTE(review): the explicit toString() calls are evaluated before
+ // the logger is invoked; presumably the logger formats "{}"
+ // arguments itself, in which case passing the objects would do --
+ // confirm.
+ SQL2REL_LOGGER.debug("{} are not unique keys for {}",
+ rightJoinKeys.toString(), right.toString());
+ return;
+ }
+
+ // Collect the field accesses that occur in the correlated join keys.
+ RexUtil.FieldAccessFinder visitor =
+ new RexUtil.FieldAccessFinder();
+ RexUtil.apply(visitor, correlatedJoinKeys, null);
+ List<RexFieldAccess> correlatedKeyList =
+ visitor.getFieldAccessList();
+
+ if (!checkCorVars(correlate, project, filter, correlatedKeyList)) {
+ return;
+ }
+
+ // Change the plan to this structure.
+ // Note that the Aggregate is removed.
+ //
+ // Project-A' (replace corVar to input ref from the Join)
+ // Join (replace corVar to input ref from leftInput)
+ // leftInput
+ // rightInput (previously filterInput)
+
+ // Change the filter condition into a join condition
+ joinCond =
+ removeCorrelationExpr(filter.getCondition(), false);
+
+ // The first right join key, shifted past the left input's fields,
+ // serves as the null indicator.
+ nullIndicatorPos =
+ left.getRowType().getFieldCount()
+ + rightJoinKeys.get(0).getIndex();
+ } else if (cm.mapRefRelToCorRef.containsKey(project)) {
+ // check the project input contains no correlation
+ if (RelOptUtil.getVariablesUsed(right).size() > 0) {
+ return;
+ }
+
+ if (!checkCorVars(correlate, project, null, null)) {
+ return;
+ }
+
+ // Change the plan to this structure.
+ //
+ // Project-A' (replace corVar to input ref from Join)
+ // Join (left, condition = true)
+ // leftInput
+ // Aggregate(groupby(0), single_value(0), s_v(1)....)
+ // Project-B (everything from input plus literal true)
+ // projectInput
+
+ // make the new Project to provide a null indicator
+ right =
+ createProjectWithAdditionalExprs(right,
+ ImmutableList.of(
+ Pair.of(relBuilder.literal(true), "nullIndicator")));
+
+ // make the new aggRel
+ right =
+ RelOptUtil.createSingleValueAggRel(cluster, right);
+
+ // The last field:
+ // single_value(true)
+ // is the nullIndicator
+ nullIndicatorPos =
+ left.getRowType().getFieldCount()
+ + right.getRowType().getFieldCount() - 1;
+ } else {
+ // Neither the project nor a filter below it is correlated.
+ return;
+ }
+
+ // make the new join rel
+ LogicalJoin join =
+ LogicalJoin.create(left, right, joinCond,
+ ImmutableSet.of(), joinType);
+
+ RelNode newProject =
+ projectJoinOutputWithNullability(join, project, nullIndicatorPos);
+
+ call.transformTo(newProject);
+
+ // The correlate has been replaced, so its correlation variable is
+ // dropped from the correlation map.
+ removeCorVarFromTree(correlate);
+ }
+ }
+
+ /** Planner rule that removes correlations for scalar aggregates. */
+ private final class RemoveCorrelationForScalarAggregateRule
+ extends RelOptRule {
+ RemoveCorrelationForScalarAggregateRule(RelBuilderFactory relBuilderFactory) {
+ super(
+ operand(LogicalCorrelate.class,
+ operand(RelNode.class, any()),
+ operand(LogicalProject.class,
+ operandJ(LogicalAggregate.class, null, Aggregate::isSimple,
+ operand(LogicalProject.class,
+ operand(RelNode.class, any()))))),
+ relBuilderFactory, null);
+ }
+
+ public void onMatch(RelOptRuleCall call) {
+ final LogicalCorrelate correlate = call.rel(0);
+ final RelNode left = call.rel(1);
+ final LogicalProject aggOutputProject = call.rel(2);
+ final LogicalAggregate aggregate = call.rel(3);
+ final LogicalProject aggInputProject = call.rel(4);
+ RelNode right = call.rel(5);
+ final RelBuilder builder = call.builder();
+ final RexBuilder rexBuilder = builder.getRexBuilder();
+ final RelOptCluster cluster = correlate.getCluster();
+
+ setCurrent(call.getPlanner().getRoot(), correlate);
+
+ // check for this pattern
+ // The pattern matching could be simplified if rules can be applied
+ // during decorrelation,
+ //
+ // CorrelateRel(left correlation, condition = true)
+ // leftInput
+ // Project-A (a RexNode)
+ // Aggregate (groupby (0), agg0(), agg1()...)
+ // Project-B (references coVar)
+ // rightInput
+
+ // check aggOutputProject projects only one expression
+ final List<RexNode> aggOutputProjects = aggOutputProject.getProjects();
+ if (aggOutputProjects.size() != 1) {
+ return;
+ }
+
+ final JoinRelType joinType = toJoinRelType(correlate.getJoinType());
+ // corRel.getCondition was here, however Correlate was updated so it
+ // never includes a join condition. The code was not modified for brevity.
+ RexNode joinCond = rexBuilder.makeLiteral(true);
+ if ((joinType != JoinRelType.LEFT)
+ || (joinCond != rexBuilder.makeLiteral(true))) {
+ return;
+ }
+
+ // check that the agg is on the entire input
+ if (!aggregate.getGroupSet().isEmpty()) {
+ return;
+ }
+
+ final List<RexNode> aggInputProjects = aggInputProject.getProjects();
+
+ final List<AggregateCall> aggCalls = aggregate.getAggCallList();
+ final Set<Integer> isCountStar = new HashSet<>();
+
+ // mark if agg produces count(*) which needs to reference the
+ // nullIndicator after the transformation.
+ int k = -1;
+ for (AggregateCall aggCall : aggCalls) {
+ ++k;
+ if ((aggCall.getAggregation() instanceof SqlCountAggFunction)
+ && (aggCall.getArgList().size() == 0)) {
+ isCountStar.add(k);
+ }
+ }
+
+ if ((right instanceof LogicalFilter)
+ && cm.mapRefRelToCorRef.containsKey(right)) {
+ // rightInput has this shape:
+ //
+ // Filter (references corVar)
+ // filterInput
+ LogicalFilter filter = (LogicalFilter) right;
+ right = filter.getInput();
+
+ assert right instanceof HepRelVertex;
+ right = ((HepRelVertex) right).getCurrentRel();
+
+ // check filter input contains no correlation
+ if (RelOptUtil.getVariablesUsed(right).size() > 0) {
+ return;
+ }
+
+ // check filter condition type First extract the correlation out
+ // of the filter
+
+ // First breaking up the filter conditions into equality
+ // comparisons between rightJoinKeys(from the original
+ // filterInput) and correlatedJoinKeys. correlatedJoinKeys
+ // can only be RexFieldAccess, while rightJoinKeys can be
+ // expressions. These comparisons are AND'ed together.
+ List<RexNode> rightJoinKeys = new ArrayList<>();
+ List<RexNode> tmpCorrelatedJoinKeys = new ArrayList<>();
+ RelOptUtil.splitCorrelatedFilterCondition(
+ filter,
+ rightJoinKeys,
+ tmpCorrelatedJoinKeys,
+ true);
+
+ // make sure the correlated reference forms a unique key check
+ // that the columns referenced in these comparisons form an
+ // unique key of the leftInput
+ List<RexFieldAccess> correlatedJoinKeys = new ArrayList<>();
+ List<RexInputRef> correlatedInputRefJoinKeys = new ArrayList<>();
+ for (RexNode joinKey : tmpCorrelatedJoinKeys) {
+ assert joinKey instanceof RexFieldAccess;
+ correlatedJoinKeys.add((RexFieldAccess) joinKey);
+ RexNode correlatedInputRef =
+ removeCorrelationExpr(joinKey, false);
+ assert correlatedInputRef instanceof RexInputRef;
+ correlatedInputRefJoinKeys.add(
+ (RexInputRef) correlatedInputRef);
+ }
+
+ // check that the columns referenced in rightJoinKeys form an
+ // unique key of the filterInput
+ if (correlatedInputRefJoinKeys.isEmpty()) {
+ return;
+ }
+
+ // The join filters out the nulls. So, it's ok if there are
+ // nulls in the join keys.
+ final RelMetadataQuery mq = call.getMetadataQuery();
+ if (!RelMdUtil.areColumnsDefinitelyUniqueWhenNullsFiltered(mq, left,
+ correlatedInputRefJoinKeys)) {
+ SQL2REL_LOGGER.debug("{} are not unique keys for {}",
+ correlatedJoinKeys.toString(), left.toString());
+ return;
+ }
+
+ // check corVar references are valid
+ if (!checkCorVars(correlate,
+ aggInputProject,
+ filter,
+ correlatedJoinKeys)) {
+ return;
+ }
+
+ // Rewrite the above plan:
+ //
+ // Correlate(left correlation, condition = true)
+ // leftInput
+ // Project-A (a RexNode)
+ // Aggregate (groupby(0), agg0(),agg1()...)
+ // Project-B (may reference corVar)
+ // Filter (references corVar)
+ // rightInput (no correlated reference)
+ //
+
+ // to this plan:
+ //
+ // Project-A' (all gby keys + rewritten nullable ProjExpr)
+ // Aggregate (groupby(all left input refs)
+ // agg0(rewritten expression),
+ // agg1()...)
+ // Project-B' (rewritten original projected exprs)
+ // Join(replace corVar w/ input ref from leftInput)
+ // leftInput
+ // rightInput
+ //
+
+ // In the case where agg is count(*) or count($corVar), it is
+ // changed to count(nullIndicator).
+ // Note: any non-nullable field from the RHS can be used as
+ // the indicator however a "true" field is added to the
+ // projection list from the RHS for simplicity to avoid
+ // searching for non-null fields.
+ //
+ // Project-A' (all gby keys + rewritten nullable ProjExpr)
+ // Aggregate (groupby(all left input refs),
+ // count(nullIndicator), other aggs...)
+ // Project-B' (all left input refs plus
+ // the rewritten original projected exprs)
+ // Join(replace corVar to input ref from leftInput)
+ // leftInput
+ // Project (everything from rightInput plus
+ // the nullIndicator "true")
+ // rightInput
+ //
+
+ // first change the filter condition into a join condition
+ joinCond =
+ removeCorrelationExpr(filter.getCondition(), false);
+ } else if (cm.mapRefRelToCorRef.containsKey(aggInputProject)) {
+ // check rightInput contains no correlation
+ if (RelOptUtil.getVariablesUsed(right).size() > 0) {
+ return;
+ }
+
+ // check corVar references are valid
+ if (!checkCorVars(correlate, aggInputProject, null, null)) {
+ return;
+ }
+
+ int nFields = left.getRowType().getFieldCount();
+ ImmutableBitSet allCols = ImmutableBitSet.range(nFields);
+
+ // leftInput contains unique keys
+ // i.e. each row is distinct and can group by on all the left
+ // fields
+ final RelMetadataQuery mq = call.getMetadataQuery();
+ if (!RelMdUtil.areColumnsDefinitelyUnique(mq, left, allCols)) {
+ SQL2REL_LOGGER.debug("There are no unique keys for {}", left);
+ return;
+ }
+ //
+ // Rewrite the above plan:
+ //
+ // CorrelateRel(left correlation, condition = true)
+ // leftInput
+ // Project-A (a RexNode)
+ // Aggregate (groupby(0), agg0(), agg1()...)
+ // Project-B (references coVar)
+ // rightInput (no correlated reference)
+ //
+
+ // to this plan:
+ //
+ // Project-A' (all gby keys + rewritten nullable ProjExpr)
+ // Aggregate (groupby(all left input refs)
+ // agg0(rewritten expression),
+ // agg1()...)
+ // Project-B' (rewritten original projected exprs)
+ // Join (LOJ cond = true)
+ // leftInput
+ // rightInput
+ //
+
+ // In the case where agg is count($corVar), it is changed to
+ // count(nullIndicator).
+ // Note: any non-nullable field from the RHS can be used as
+ // the indicator however a "true" field is added to the
+ // projection list from the RHS for simplicity to avoid
+ // searching for non-null fields.
+ //
+ // Project-A' (all gby keys + rewritten nullable ProjExpr)
+ // Aggregate (groupby(all left input refs),
+ // count(nullIndicator), other aggs...)
+ // Project-B' (all left input refs plus
+ // the rewritten original projected exprs)
+ // Join (replace corVar to input ref from leftInput)
+ // leftInput
+ // Project (everything from rightInput plus
+ // the nullIndicator "true")
+ // rightInput
+ } else {
+ return;
+ }
+
+ RelDataType leftInputFieldType = left.getRowType();
+ int leftInputFieldCount = leftInputFieldType.getFieldCount();
+ int joinOutputProjExprCount =
+ leftInputFieldCount + aggInputProjects.size() + 1;
+
+ right =
+ createProjectWithAdditionalExprs(right,
+ ImmutableList.of(
+ Pair.of(rexBuilder.makeLiteral(true),
+ "nullIndicator")));
+
+ LogicalJoin join =
+ LogicalJoin.create(left, right, joinCond,
+ ImmutableSet.of(), joinType);
+
+ // To the consumer of joinOutputProjRel, nullIndicator is located
+ // at the end
+ int nullIndicatorPos = join.getRowType().getFieldCount() - 1;
+
+ RexInputRef nullIndicator =
+ new RexInputRef(
+ nullIndicatorPos,
+ cluster.getTypeFactory().createTypeWithNullability(
+ join.getRowType().getFieldList()
+ .get(nullIndicatorPos).getType(),
+ true));
+
+ // first project all group-by keys plus the transformed agg input
+ List<RexNode> joinOutputProjects = new ArrayList<>();
+
+ // LOJ Join preserves LHS types
+ for (int i = 0; i < leftInputFieldCount; i++) {
+ joinOutputProjects.add(
+ rexBuilder.makeInputRef(
+ leftInputFieldType.getFieldList().get(i).getType(), i));
+ }
+
+ for (RexNode aggInputProjExpr : aggInputProjects) {
+ joinOutputProjects.add(
+ removeCorrelationExpr(aggInputProjExpr,
+ joinType.generatesNullsOnRight(),
+ nullIndicator));
+ }
+
+ joinOutputProjects.add(
+ rexBuilder.makeInputRef(join, nullIndicatorPos));
+
+ final RelNode joinOutputProject = builder.push(join)
+ .project(joinOutputProjects)
+ .build();
+
+ // nullIndicator is now at a different location in the output of
+ // the join
+ nullIndicatorPos = joinOutputProjExprCount - 1;
+
+ final int groupCount = leftInputFieldCount;
+
+ List<AggregateCall> newAggCalls = new ArrayList<>();
+ k = -1;
+ for (AggregateCall aggCall : aggCalls) {
+ ++k;
+ final List<Integer> argList;
+
+ if (isCountStar.contains(k)) {
+ // this is a count(*), transform it to count(nullIndicator)
+ // the null indicator is located at the end
+ argList = Collections.singletonList(nullIndicatorPos);
+ } else {
+ argList = new ArrayList<>();
+
+ for (int aggArg : aggCall.getArgList()) {
+ argList.add(aggArg + groupCount);
+ }
+ }
+
+ int filterArg = aggCall.filterArg < 0 ? aggCall.filterArg
+ : aggCall.filterArg + groupCount;
+ newAggCalls.add(
+ aggCall.adaptTo(joinOutputProject, argList, filterArg,
+ aggregate.getGroupCount(), groupCount));
+ }
+
+ ImmutableBitSet groupSet =
+ ImmutableBitSet.range(groupCount);
+ LogicalAggregate newAggregate =
+ LogicalAggregate.create(joinOutputProject, groupSet, null,
+ newAggCalls);
+ List<RexNode> newAggOutputProjectList = new ArrayList<>();
+ for (int i : groupSet) {
+ newAggOutputProjectList.add(
+ rexBuilder.makeInputRef(newAggregate, i));
+ }
+
+ RexNode newAggOutputProjects =
+ removeCorrelationExpr(aggOutputProjects.get(0), false);
+ newAggOutputProjectList.add(
+ rexBuilder.makeCast(
+ cluster.getTypeFactory().createTypeWithNullability(
+ newAggOutputProjects.getType(),
+ true),
+ newAggOutputProjects));
+
+ builder.push(newAggregate)
+ .project(newAggOutputProjectList);
+ call.transformTo(builder.build());
+
+ removeCorVarFromTree(correlate);
+ }
+ }
+
+ // REVIEW jhyde 29-Oct-2007: This rule is non-static, depends on the state
+ // of members in RelDecorrelator, and has side-effects in the decorrelator.
+ // This breaks the contract of a planner rule, and the rule will not be
+ // reusable in other planners.
+
+ // REVIEW jvs 29-Oct-2007: Shouldn't it also be incorporating
+ // the flavor attribute into the description?
+
+ /** Planner rule that adjusts projects when counts are added. */
+ private final class AdjustProjectForCountAggregateRule extends RelOptRule {
+ final boolean flavor;
+
+ AdjustProjectForCountAggregateRule(boolean flavor,
+ RelBuilderFactory relBuilderFactory) {
+ super(
+ flavor
+ ? operand(LogicalCorrelate.class,
+ operand(RelNode.class, any()),
+ operand(LogicalProject.class,
+ operand(LogicalAggregate.class, any())))
+ : operand(LogicalCorrelate.class,
+ operand(RelNode.class, any()),
+ operand(LogicalAggregate.class, any())),
+ relBuilderFactory, null);
+ this.flavor = flavor;
+ }
+
+ public void onMatch(RelOptRuleCall call) {
+ final LogicalCorrelate correlate = call.rel(0);
+ final RelNode left = call.rel(1);
+ final LogicalProject aggOutputProject;
+ final LogicalAggregate aggregate;
+ if (flavor) {
+ aggOutputProject = call.rel(2);
+ aggregate = call.rel(3);
+ } else {
+ aggregate = call.rel(2);
+
+ // Create identity projection
+ final List<Pair<RexNode, String>> projects = new ArrayList<>();
+ final List<RelDataTypeField> fields =
+ aggregate.getRowType().getFieldList();
+ for (int i = 0; i < fields.size(); i++) {
+ projects.add(RexInputRef.of2(projects.size(), fields));
+ }
+ final RelBuilder relBuilder = call.builder();
+ relBuilder.push(aggregate)
+ .projectNamed(Pair.left(projects), Pair.right(projects), true);
+ aggOutputProject = (LogicalProject) relBuilder.build();
+ }
+ onMatch2(call, correlate, left, aggOutputProject, aggregate);
+ }
+
+ private void onMatch2(
+ RelOptRuleCall call,
+ LogicalCorrelate correlate,
+ RelNode leftInput,
+ LogicalProject aggOutputProject,
+ LogicalAggregate aggregate) {
+ if (generatedCorRels.contains(correlate)) {
+ // This Correlate was generated by a previous invocation of
+ // this rule. No further work to do.
+ return;
+ }
+
+ setCurrent(call.getPlanner().getRoot(), correlate);
+
+ // check for this pattern
+ // The pattern matching could be simplified if rules can be applied
+ // during decorrelation,
+ //
+ // CorrelateRel(left correlation, condition = true)
+ // leftInput
+ // Project-A (a RexNode)
+ // Aggregate (groupby (0), agg0(), agg1()...)
+
+ // check aggOutputProj projects only one expression
+ List<RexNode> aggOutputProjExprs = aggOutputProject.getProjects();
+ if (aggOutputProjExprs.size() != 1) {
+ return;
+ }
+
+ JoinRelType joinType = toJoinRelType(correlate.getJoinType());
+ // corRel.getCondition was here, however Correlate was updated so it
+ // never includes a join condition. The code was not modified for brevity.
+ RexNode joinCond = relBuilder.literal(true);
+ if ((joinType != JoinRelType.LEFT)
+ || (joinCond != relBuilder.literal(true))) {
+ return;
+ }
+
+ // check that the agg is on the entire input
+ if (!aggregate.getGroupSet().isEmpty()) {
+ return;
+ }
+
+ List<AggregateCall> aggCalls = aggregate.getAggCallList();
+ Set<Integer> isCount = new HashSet<>();
+
+ // remember the count() positions
+ int i = -1;
+ for (AggregateCall aggCall : aggCalls) {
+ ++i;
+ if (aggCall.getAggregation() instanceof SqlCountAggFunction) {
+ isCount.add(i);
+ }
+ }
+
+ // now rewrite the plan to
+ //
+ // Project-A' (all LHS plus transformed original projections,
+ // replacing references to count() with case statement)
+ // Correlate(left correlation, condition = true)
+ // leftInput
+ // Aggregate(groupby (0), agg0(), agg1()...)
+ //
+ LogicalCorrelate newCorrelate =
+ LogicalCorrelate.create(leftInput, aggregate,
+ correlate.getCorrelationId(), correlate.getRequiredColumns(),
+ correlate.getJoinType());
+
+ // remember this rel so we don't fire rule on it again
+ // REVIEW jhyde 29-Oct-2007: rules should not save state; rule
+ // should recognize patterns where it does or does not need to do
+ // work
+ generatedCorRels.add(newCorrelate);
+
+ // need to update the mapCorToCorRel Update the output position
+ // for the corVars: only pass on the corVars that are not used in
+ // the join key.
+ if (cm.mapCorToCorRel.get(correlate.getCorrelationId()) == correlate) {
+ cm.mapCorToCorRel.put(correlate.getCorrelationId(), newCorrelate);
+ }
+
+ RelNode newOutput =
+ aggregateCorrelatorOutput(newCorrelate, aggOutputProject, isCount);
+
+ call.transformTo(newOutput);
+ }
+ }
+
+  /**
+   * A unique reference to a correlation field.
+   *
+   * <p>For instance, if a RelNode references emp.name multiple times, it would
+   * result in multiple {@code CorRef} objects that differ just in
+   * {@link CorRef#uniqueKey}.
+   *
+   * <p>{@link #equals}, {@link #hashCode} and {@link #compareTo} all look at
+   * the same three components: correlation id, field and unique key.
+   */
+  static class CorRef implements Comparable<CorRef> {
+    /** Distinguishes several references to the same correlation field. */
+    public final int uniqueKey;
+    /** Correlation id that the field is read from. */
+    public final CorrelationId corr;
+    /** Index of the referenced field. */
+    public final int field;
+
+    CorRef(CorrelationId corr, int field, int uniqueKey) {
+      this.corr = corr;
+      this.field = field;
+      this.uniqueKey = uniqueKey;
+    }
+
+    @Override public String toString() {
+      return corr.getName() + '.' + field;
+    }
+
+    @Override public int hashCode() {
+      return Objects.hash(uniqueKey, corr, field);
+    }
+
+    /**
+     * Compares all three components. The correlation id is compared with
+     * {@code equals} rather than {@code ==}, so that this method agrees with
+     * {@link #hashCode} (which hashes {@code corr} through its own
+     * {@code hashCode}) and with {@link #compareTo} (which delegates to
+     * {@code corr.compareTo}).
+     */
+    @Override public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (!(o instanceof CorRef)) {
+        return false;
+      }
+      final CorRef that = (CorRef) o;
+      return uniqueKey == that.uniqueKey
+          && field == that.field
+          && Objects.equals(corr, that.corr);
+    }
+
+    /** Orders by correlation id, then field, then unique key. */
+    @Override public int compareTo(@Nonnull CorRef o) {
+      int c = corr.compareTo(o.corr);
+      if (c != 0) {
+        return c;
+      }
+      c = Integer.compare(field, o.field);
+      if (c != 0) {
+        return c;
+      }
+      return Integer.compare(uniqueKey, o.uniqueKey);
+    }
+
+    /** Returns the definition (correlation id and field) without the unique key. */
+    public CorDef def() {
+      return new CorDef(corr, field);
+    }
+  }
+
+  /**
+   * A correlation and a field.
+   *
+   * <p>{@link #equals}, {@link #hashCode} and {@link #compareTo} all look at
+   * the same two components.
+   */
+  static class CorDef implements Comparable<CorDef> {
+    /** Correlation id that the field is read from. */
+    public final CorrelationId corr;
+    /** Index of the field. */
+    public final int field;
+
+    CorDef(CorrelationId corr, int field) {
+      this.corr = corr;
+      this.field = field;
+    }
+
+    @Override public String toString() {
+      return corr.getName() + '.' + field;
+    }
+
+    @Override public int hashCode() {
+      return Objects.hash(corr, field);
+    }
+
+    /**
+     * Compares correlation id and field. The correlation id is compared with
+     * {@code equals} rather than {@code ==}, so that this method agrees with
+     * {@link #hashCode} and {@link #compareTo}, both of which go through the
+     * id's own methods.
+     */
+    @Override public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (!(o instanceof CorDef)) {
+        return false;
+      }
+      final CorDef that = (CorDef) o;
+      return field == that.field
+          && Objects.equals(corr, that.corr);
+    }
+
+    /** Orders by correlation id, then field. */
+    @Override public int compareTo(@Nonnull CorDef o) {
+      int c = corr.compareTo(o.corr);
+      if (c != 0) {
+        return c;
+      }
+      return Integer.compare(field, o.field);
+    }
+  }
+
+  /**
+   * Bookkeeping for the
+   * {@link org.apache.calcite.rel.logical.LogicalCorrelate} nodes found in a
+   * tree of {@link RelNode}s; it drives the decorrelation process.
+   *
+   * <p>Treat an instance as immutable and rebuild it if you modify the tree.
+   * It bundles three maps:
+   *
+   * <ul>
+   * <li>{@link #mapRefRelToCorRef}: a {@link RelNode} to the correlated
+   * variables it references;
+   * <li>{@link #mapCorToCorRel}: a correlated variable to the
+   * {@link Correlate} providing it;
+   * <li>{@link #mapFieldAccessToCorRef}: a rex field access to the corVar it
+   * represents. Because typeFlattener does not clone or modify a correlated
+   * field access, this map does not need to be updated.
+   * </ul>
+   */
+  private static class CorelMap {
+    private final Multimap<RelNode, CorRef> mapRefRelToCorRef;
+    private final SortedMap<CorrelationId, RelNode> mapCorToCorRel;
+    private final Map<RexFieldAccess, CorRef> mapFieldAccessToCorRef;
+
+    // TODO: create immutable copies of all maps
+    // (today only the field-access map is copied; the other two are kept by
+    // reference)
+    private CorelMap(Multimap<RelNode, CorRef> mapRefRelToCorRef,
+        SortedMap<CorrelationId, RelNode> mapCorToCorRel,
+        Map<RexFieldAccess, CorRef> mapFieldAccessToCorRef) {
+      this.mapFieldAccessToCorRef = ImmutableMap.copyOf(mapFieldAccessToCorRef);
+      this.mapCorToCorRel = mapCorToCorRel;
+      this.mapRefRelToCorRef = mapRefRelToCorRef;
+    }
+
+    /** Creates a CorelMap with given contents. */
+    public static CorelMap of(
+        SortedSetMultimap<RelNode, CorRef> mapRefRelToCorVar,
+        SortedMap<CorrelationId, RelNode> mapCorToCorRel,
+        Map<RexFieldAccess, CorRef> mapFieldAccessToCorVar) {
+      return new CorelMap(mapRefRelToCorVar, mapCorToCorRel,
+          mapFieldAccessToCorVar);
+    }
+
+    /**
+     * Returns whether there are any correlating variables in this statement.
+     *
+     * @return whether there are any correlating variables
+     */
+    public boolean hasCorrelation() {
+      return !mapCorToCorRel.isEmpty();
+    }
+
+    /** Lists the three maps, one per line, each line ending in a newline. */
+    @Override public String toString() {
+      final StringBuilder text = new StringBuilder();
+      text.append("mapRefRelToCorRef=").append(mapRefRelToCorRef);
+      text.append("\nmapCorToCorRel=").append(mapCorToCorRel);
+      text.append("\nmapFieldAccessToCorRef=").append(mapFieldAccessToCorRef);
+      text.append("\n");
+      return text.toString();
+    }
+
+    /** Two maps are equal when all three component maps are equal. */
+    @Override public boolean equals(Object obj) {
+      if (obj == this) {
+        return true;
+      }
+      if (!(obj instanceof CorelMap)) {
+        return false;
+      }
+      final CorelMap other = (CorelMap) obj;
+      if (!mapRefRelToCorRef.equals(other.mapRefRelToCorRef)) {
+        return false;
+      }
+      if (!mapCorToCorRel.equals(other.mapCorToCorRel)) {
+        return false;
+      }
+      return mapFieldAccessToCorRef.equals(other.mapFieldAccessToCorRef);
+    }
+
+    @Override public int hashCode() {
+      return Objects.hash(mapRefRelToCorRef, mapCorToCorRel,
+          mapFieldAccessToCorRef);
+    }
+  }
+
+ /**
+ * Builds a {@link org.apache.calcite.sql2rel.RelDecorrelator.CorelMap}.
+ *
+ * <p>Walks one or more {@link RelNode} trees as a shuttle, recording every
+ * {@link LogicalCorrelate} it meets and every field access on a correlation
+ * variable found in join conditions, filter conditions and project
+ * expressions (including those inside sub-queries).
+ */
+ private static class CorelMapBuilder extends RelShuttleImpl {
+ // Correlation id -> the LogicalCorrelate visited with that id.
+ final SortedMap<CorrelationId, RelNode> mapCorToCorRel =
+ new TreeMap<>();
+
+ // Rel -> correlation references found in that rel's expressions.
+ final SortedSetMultimap<RelNode, CorRef> mapRefRelToCorRef =
+ MultimapBuilder.SortedSetMultimapBuilder.hashKeys()
+ .treeSetValues()
+ .build();
+
+ // Field access on a correlation variable -> the CorRef created for it.
+ final Map<RexFieldAccess, CorRef> mapFieldAccessToCorVar = new HashMap<>();
+
+ // Field offset maintained by visitJoin; within this class it is only read there.
+ final Holder<Integer> offset = Holder.of(0);
+ // Source of CorRef.uniqueKey values; incremented for each new CorRef.
+ int corrIdGenerator = 0;
+
+ /** Creates a CorelMap by iterating over a {@link RelNode} tree. */
+ CorelMap build(RelNode... rels) {
+ for (RelNode rel : rels) {
+ // NOTE(review): stripHep is defined outside this class; presumably it
+ // unwraps planner wrapper nodes -- confirm.
+ stripHep(rel).accept(this);
+ }
+ return new CorelMap(mapRefRelToCorRef, mapCorToCorRel,
+ mapFieldAccessToCorVar);
+ }
+
+ /** Scans the join condition for correlation references, then visits both inputs. */
+ @Override public RelNode visit(LogicalJoin join) {
+ try {
+ // 'stack' is not declared in this class; presumably inherited from
+ // RelShuttleImpl. The join is on it only while its condition is scanned.
+ stack.push(join);
+ join.getCondition().accept(rexVisitor(join));
+ } finally {
+ stack.pop();
+ }
+ return visitJoin(join);
+ }
+
+ /** Visits a child after passing it through {@code stripHep}. */
+ @Override protected RelNode visitChild(RelNode parent, int i,
+ RelNode input) {
+ return super.visitChild(parent, i, stripHep(input));
+ }
+
+ /** Registers the correlate under its correlation id, then visits both inputs. */
+ @Override public RelNode visit(LogicalCorrelate correlate) {
+ mapCorToCorRel.put(correlate.getCorrelationId(), correlate);
+ return visitJoin(correlate);
+ }
+
+ /**
+ * Visits left then right input of a two-input rel. While the right input is
+ * visited, {@code offset} is advanced by the left input's field count; it is
+ * restored to its entry value afterwards. Returns the rel unchanged.
+ */
+ private RelNode visitJoin(BiRel join) {
+ final int x = offset.get();
+ visitChild(join, 0, join.getLeft());
+ offset.set(x + join.getLeft().getRowType().getFieldCount());
+ visitChild(join, 1, join.getRight());
+ offset.set(x);
+ return join;
+ }
+
+ /** Scans the filter condition for correlation references, then continues the default visit. */
+ @Override public RelNode visit(final LogicalFilter filter) {
+ try {
+ stack.push(filter);
+ filter.getCondition().accept(rexVisitor(filter));
+ } finally {
+ stack.pop();
+ }
+ return super.visit(filter);
+ }
+
+ /** Scans every project expression for correlation references, then continues the default visit. */
+ @Override public RelNode visit(LogicalProject project) {
+ try {
+ stack.push(project);
+ for (RexNode node : project.getProjects()) {
+ node.accept(rexVisitor(project));
+ }
+ } finally {
+ stack.pop();
+ }
+ return super.visit(project);
+ }
+
+ /**
+ * Creates an expression visitor that attributes the correlation references
+ * it finds to {@code rel}.
+ */
+ private RexVisitorImpl<Void> rexVisitor(final RelNode rel) {
+ return new RexVisitorImpl<Void>(true) {
+ @Override public Void visitFieldAccess(RexFieldAccess fieldAccess) {
+ final RexNode ref = fieldAccess.getReferenceExpr();
+ // Only field accesses on a correlation variable are recorded.
+ if (ref instanceof RexCorrelVariable) {
+ final RexCorrelVariable var = (RexCorrelVariable) ref;
+ if (mapFieldAccessToCorVar.containsKey(fieldAccess)) {
+ // for cases where different Rel nodes are referring to
+ // same correlation var (e.g. in case of NOT IN)
+ // avoid generating another correlation var
+ // and record the 'rel' is using the same correlation
+ mapRefRelToCorRef.put(rel,
+ mapFieldAccessToCorVar.get(fieldAccess));
+ } else {
+ // First sighting of this field access: allocate a CorRef with a
+ // fresh unique key and record it in both maps.
+ final CorRef correlation =
+ new CorRef(var.id, fieldAccess.getField().getIndex(),
+ corrIdGenerator++);
+ mapFieldAccessToCorVar.put(fieldAccess, correlation);
+ mapRefRelToCorRef.put(rel, correlation);
+ }
+ }
+ return super.visitFieldAccess(fieldAccess);
+ }
+
+ @Override public Void visitSubQuery(RexSubQuery subQuery) {
+ // Recurse into the sub-query's rel tree with the enclosing builder.
+ subQuery.rel.accept(CorelMapBuilder.this);
+ return super.visitSubQuery(subQuery);
+ }
+ };
+ }
+ }
+
+ /** Frame describing the relational expression after decorrelation
+ * and where to find the output fields and correlation variables
+ * among its output fields.
+ *
+ * <p>Both maps are stored as immutable sorted copies of the arguments. */
+ static class Frame {
+ // The relational expression after decorrelation; never null.
+ final RelNode r;
+ // Correlation definition -> field position in r's output.
+ final ImmutableSortedMap<CorDef, Integer> corDefOutputs;
+ // Field position in the old rel's output -> field position in r's output.
+ final ImmutableSortedMap<Integer, Integer> oldToNewOutputs;
+
+ /**
+ * Creates a Frame.
+ *
+ * @param oldRel rel before decorrelation; used only in the assertion that
+ * bounds the keys of {@code oldToNewOutputs}
+ * @param r rel after decorrelation; must not be null
+ * @param corDefOutputs positions of correlation definitions in {@code r}
+ * @param oldToNewOutputs mapping of old output positions to new ones
+ */
+ Frame(RelNode oldRel, RelNode r, SortedMap<CorDef, Integer> corDefOutputs,
+ Map<Integer, Integer> oldToNewOutputs) {
+ this.r = Objects.requireNonNull(r);
+ this.corDefOutputs = ImmutableSortedMap.copyOf(corDefOutputs);
+ this.oldToNewOutputs = ImmutableSortedMap.copyOf(oldToNewOutputs);
+ // With assertions enabled, check that every recorded position is below
+ // the field count of the rel it refers to.
+ assert allLessThan(this.corDefOutputs.values(),
+ r.getRowType().getFieldCount(), Litmus.THROW);
+ assert allLessThan(this.oldToNewOutputs.keySet(),
+ oldRel.getRowType().getFieldCount(), Litmus.THROW);
+ assert allLessThan(this.oldToNewOutputs.values(),
+ r.getRowType().getFieldCount(), Litmus.THROW);
+ }
+ }
+}
+
+// End RelDecorrelator.java
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/api/PlannerConfigOptions.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/api/PlannerConfigOptions.java
index 48b87c8..557913f 100644
--- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/api/PlannerConfigOptions.java
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/api/PlannerConfigOptions.java
@@ -73,6 +73,12 @@ public class PlannerConfigOptions {
"instance that holds a partition of all data when performing a hash join. " +
"Broadcast will be disabled if the value is -1.");
+ /**
+ * Option {@code sql.optimizer.semi-anti-join.build-distinct.ndv-ratio}, default 0.8.
+ * Per its description, it relates to adding a distinct node before a semi/anti join.
+ * NOTE(review): how the ratio is compared is not visible here; confirm at the usage site.
+ */
+ public static final ConfigOption<Double> SQL_OPTIMIZER_SEMI_JOIN_BUILD_DISTINCT_NDV_RATIO =
+ key("sql.optimizer.semi-anti-join.build-distinct.ndv-ratio")
+ .defaultValue(0.8)
+ .withDescription("When the semi-side of semi/anti join can distinct a lot of data in advance," +
+ " we will add distinct node before semi/anti join.");
+
public static final ConfigOption<Boolean> SQL_OPTIMIZER_DATA_SKEW_DISTINCT_AGG_ENABLED =
key("sql.optimizer.data-skew.distinct-agg.enabled")
.defaultValue(false)
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkFilterJoinRule.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkFilterJoinRule.java
new file mode 100644
index 0000000..dd5be0f
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkFilterJoinRule.java
@@ -0,0 +1,363 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.rules.logical;
+
+import org.apache.flink.table.plan.util.FlinkRelOptUtil;
+
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptRuleOperand;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.EquiJoin;
+import org.apache.calcite.rel.core.Filter;
+import org.apache.calcite.rel.core.Join;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.core.RelFactories;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.tools.RelBuilderFactory;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Objects;
+
+import static org.apache.calcite.plan.RelOptUtil.conjunctions;
+
+/**
+ * This rule is copied from Calcite's {@link org.apache.calcite.rel.rules.FilterJoinRule}.
+ * Modification:
+ * - Use `FlinkRelOptUtil.classifyFilters` to support SEMI/ANTI join
+ * - TODO: handle the case where the ON condition of an anti-join cannot be pushed down
+ */
+
+/**
+ * Planner rule that pushes filters above and
+ * within a join node into the join node and/or its children nodes.
+ */
+public abstract class FlinkFilterJoinRule extends RelOptRule {
+ /** Predicate that always returns true. With this predicate, every filter
+ * will be pushed into the ON clause. */
+ public static final Predicate TRUE_PREDICATE = (join, joinType, exp) -> true;
+
+ /** Rule that pushes predicates from a Filter into the Join below them.
+ * Created with {@code smart = true}. */
+ public static final FlinkFilterJoinRule FILTER_ON_JOIN =
+ new FlinkFilterIntoJoinRule(true, RelFactories.LOGICAL_BUILDER,
+ TRUE_PREDICATE);
+
+ /** Dumber version of {@link #FILTER_ON_JOIN}. Not intended for production
+ * use, but keeps some tests working for which {@code FILTER_ON_JOIN} is too
+ * smart. Created with {@code smart = false}. */
+ public static final FlinkFilterJoinRule DUMB_FILTER_ON_JOIN =
+ new FlinkFilterIntoJoinRule(false, RelFactories.LOGICAL_BUILDER,
+ TRUE_PREDICATE);
+
+ /** Rule that pushes predicates in a Join into the inputs to the join.
+ * Matches a Join without requiring a Filter above it. */
+ public static final FlinkFilterJoinRule JOIN =
+ new FlinkJoinConditionPushRule(RelFactories.LOGICAL_BUILDER, TRUE_PREDICATE);
+
+ /** Whether to try to strengthen join-type. When set, and there are filters
+ * above a non-INNER join, {@code perform} asks
+ * {@code FlinkRelOptUtil.simplifyJoin} for a possibly different join type. */
+ private final boolean smart;
+
+ /** Predicate that returns whether a filter is valid in the ON clause of a
+ * join for this particular kind of join. If not, Calcite will push it back to
+ * above the join. Never null. */
+ private final Predicate predicate;
+
+ //~ Constructors -----------------------------------------------------------
+
+ /**
+ * Creates a FlinkFilterJoinRule with an explicit root operand and
+ * factories.
+ *
+ * @param operand root operand of the rule
+ * @param id suffix appended to {@code "FlinkFilterJoinRule:"} to form the
+ * rule description
+ * @param smart whether to try to strengthen the join type
+ * @param relBuilderFactory builder factory passed to {@link RelOptRule}
+ * @param predicate decides whether a filter may stay in the ON clause;
+ * must not be null
+ */
+ // NOTE(review): the nested rule classes in this file pass ids that already
+ // begin with "FlinkFilterJoinRule:", so the prefix appears twice in the
+ // resulting description -- confirm whether that is intended.
+ protected FlinkFilterJoinRule(RelOptRuleOperand operand, String id,
+ boolean smart, RelBuilderFactory relBuilderFactory, Predicate predicate) {
+ super(operand, relBuilderFactory, "FlinkFilterJoinRule:" + id);
+ this.smart = smart;
+ this.predicate = Objects.requireNonNull(predicate);
+ }
+
+ /**
+ * Creates a FlinkFilterJoinRule with an explicit root operand and
+ * factories, using {@link #TRUE_PREDICATE}.
+ *
+ * @param operand root operand of the rule
+ * @param id suffix of the rule description
+ * @param smart whether to try to strengthen the join type
+ * @param filterFactory filter factory used to derive the builder factory
+ * @param projectFactory project factory used to derive the builder factory
+ */
+ @Deprecated // to be removed before 2.0
+ protected FlinkFilterJoinRule(RelOptRuleOperand operand, String id,
+ boolean smart, RelFactories.FilterFactory filterFactory,
+ RelFactories.ProjectFactory projectFactory) {
+ this(operand, id, smart, RelBuilder.proto(filterFactory, projectFactory),
+ TRUE_PREDICATE);
+ }
+
+ /**
+ * Creates a FlinkFilterJoinRule with an explicit root operand,
+ * factories and predicate.
+ *
+ * @param operand root operand of the rule
+ * @param id suffix of the rule description
+ * @param smart whether to try to strengthen the join type
+ * @param filterFactory filter factory used to derive the builder factory
+ * @param projectFactory project factory used to derive the builder factory
+ * @param predicate decides whether a filter may stay in the ON clause
+ */
+ @Deprecated // to be removed before 2.0
+ protected FlinkFilterJoinRule(RelOptRuleOperand operand, String id,
+ boolean smart, RelFactories.FilterFactory filterFactory,
+ RelFactories.ProjectFactory projectFactory,
+ Predicate predicate) {
+ this(operand, id, smart, RelBuilder.proto(filterFactory, projectFactory),
+ predicate);
+ }
+
+ //~ Methods ----------------------------------------------------------------
+
+ /**
+ * Pushes filter conditions into {@code join} and/or its inputs and, if
+ * anything changed, registers the rewritten plan through
+ * {@code call.transformTo}.
+ *
+ * <p>In order: optionally simplifies the join type, classifies the filters
+ * above the join, lets {@link #validateJoinFilters} move unsupported join
+ * filters back above, classifies the ON-clause filters, then rebuilds
+ * inputs, join and any remaining filter above it.
+ *
+ * @param call rule call supplying the builder and the planner
+ * @param filter filter above the join, or {@code null} when only the join
+ * was matched
+ * @param join join to push filters into
+ */
+ protected void perform(RelOptRuleCall call, Filter filter,
+ Join join) {
+ final List<RexNode> joinFilters =
+ RelOptUtil.conjunctions(join.getCondition());
+ final List<RexNode> origJoinFilters = com.google.common.collect.ImmutableList.copyOf(joinFilters);
+
+ // If there is only the joinRel,
+ // make sure it does not match a cartesian product joinRel
+ // (with "true" condition), otherwise this rule will be applied
+ // again on the new cartesian product joinRel.
+ if (filter == null && joinFilters.isEmpty()) {
+ return;
+ }
+
+ final List<RexNode> aboveFilters =
+ filter != null
+ ? conjunctions(filter.getCondition())
+ : new ArrayList<>();
+ final com.google.common.collect.ImmutableList<RexNode> origAboveFilters =
+ com.google.common.collect.ImmutableList.copyOf(aboveFilters);
+
+ // Simplify Outer Joins
+ JoinRelType joinType = join.getJoinType();
+ if (smart
+ && !origAboveFilters.isEmpty()
+ && join.getJoinType() != JoinRelType.INNER) {
+ joinType = FlinkRelOptUtil.simplifyJoin(join, origAboveFilters, joinType);
+ }
+
+ final List<RexNode> leftFilters = new ArrayList<>();
+ final List<RexNode> rightFilters = new ArrayList<>();
+
+ // TODO - add logic to derive additional filters. E.g., from
+ // (t1.a = 1 AND t2.a = 2) OR (t1.b = 3 AND t2.b = 4), you can
+ // derive table filters:
+ // (t1.a = 1 OR t1.b = 3)
+ // (t2.a = 2 OR t2.b = 4)
+
+ // Try to push down above filters. These are typically where clause
+ // filters. They can be pushed down if they are not on the NULL
+ // generating side.
+ boolean filterPushed = false;
+ if (FlinkRelOptUtil.classifyFilters(
+ join,
+ aboveFilters,
+ joinType,
+ !(join instanceof EquiJoin),
+ !joinType.generatesNullsOnLeft(),
+ !joinType.generatesNullsOnRight(),
+ joinFilters,
+ leftFilters,
+ rightFilters)) {
+ filterPushed = true;
+ }
+
+ // Move join filters up if needed
+ validateJoinFilters(aboveFilters, joinFilters, join, joinType);
+
+ // If no filter got pushed after validate, reset filterPushed flag
+ if (leftFilters.isEmpty()
+ && rightFilters.isEmpty()
+ && joinFilters.size() == origJoinFilters.size()) {
+ if (com.google.common.collect.Sets.newHashSet(joinFilters)
+ .equals(com.google.common.collect.Sets.newHashSet(origJoinFilters))) {
+ filterPushed = false;
+ }
+ }
+
+ // Try to push down filters in ON clause. A ON clause filter can only be
+ // pushed down if it does not affect the non-matching set, i.e. it is
+ // not on the side which is preserved.
+ if (FlinkRelOptUtil.classifyFilters(
+ join,
+ joinFilters,
+ joinType,
+ false,
+ !joinType.generatesNullsOnRight(),
+ !joinType.generatesNullsOnLeft(),
+ joinFilters,
+ leftFilters,
+ rightFilters)) {
+ filterPushed = true;
+ }
+
+ // if nothing actually got pushed and there is nothing leftover,
+ // then this rule is a no-op
+ if ((!filterPushed
+ && joinType == join.getJoinType())
+ || (joinFilters.isEmpty()
+ && leftFilters.isEmpty()
+ && rightFilters.isEmpty())) {
+ return;
+ }
+
+ // create Filters on top of the children if any filters were
+ // pushed to them
+ final RexBuilder rexBuilder = join.getCluster().getRexBuilder();
+ final RelBuilder relBuilder = call.builder();
+ final RelNode leftRel =
+ relBuilder.push(join.getLeft()).filter(leftFilters).build();
+ final RelNode rightRel =
+ relBuilder.push(join.getRight()).filter(rightFilters).build();
+
+ // create the new join node referencing the new children and
+ // containing its new join filters (if there are any)
+ final com.google.common.collect.ImmutableList<RelDataType> fieldTypes =
+ com.google.common.collect.ImmutableList.<RelDataType>builder()
+ .addAll(RelOptUtil.getFieldTypeList(leftRel.getRowType()))
+ .addAll(RelOptUtil.getFieldTypeList(rightRel.getRowType())).build();
+ final RexNode joinFilter =
+ RexUtil.composeConjunction(rexBuilder,
+ RexUtil.fixUp(rexBuilder, joinFilters, fieldTypes));
+
+ // If nothing actually got pushed and there is nothing leftover,
+ // then this rule is a no-op
+ if (joinFilter.isAlwaysTrue()
+ && leftFilters.isEmpty()
+ && rightFilters.isEmpty()
+ && joinType == join.getJoinType()) {
+ return;
+ }
+
+ RelNode newJoinRel =
+ join.copy(
+ join.getTraitSet(),
+ joinFilter,
+ leftRel,
+ rightRel,
+ joinType,
+ join.isSemiJoinDone());
+ call.getPlanner().onCopy(join, newJoinRel);
+ // NOTE(review): filter is null when this method is reached from
+ // FlinkJoinConditionPushRule, so onCopy may receive null below -- confirm
+ // the planner tolerates that.
+ if (!leftFilters.isEmpty()) {
+ call.getPlanner().onCopy(filter, leftRel);
+ }
+ if (!rightFilters.isEmpty()) {
+ call.getPlanner().onCopy(filter, rightRel);
+ }
+
+ relBuilder.push(newJoinRel);
+
+ // Create a project on top of the join if some of the columns have become
+ // NOT NULL due to the join-type getting stricter.
+ relBuilder.convert(join.getRowType(), false);
+
+ // create a FilterRel on top of the join if needed
+ relBuilder.filter(
+ RexUtil.fixUp(rexBuilder, aboveFilters,
+ RelOptUtil.getFieldTypeList(relBuilder.peek().getRowType())));
+
+ call.transformTo(relBuilder.build());
+ }
+
+    /**
+     * Checks each join filter against this rule's {@link Predicate} and moves
+     * the rejected ones above the join.
+     *
+     * <p>A filter for which the predicate returns {@code false} (for example
+     * {@code l.c1 > r.c2} when the target join only supports equi-joins) is
+     * removed from {@code joinFilters} and appended to {@code aboveFilters}.
+     * With {@link #TRUE_PREDICATE} no filter is moved.
+     *
+     * @param aboveFilters Filter above Join; receives the rejected filters
+     * @param joinFilters Filters in join condition; rejected filters are removed
+     * @param join Join
+     * @param joinType JoinRelType could be different from type in Join due to
+     *     outer join simplification.
+     */
+    protected void validateJoinFilters(List<RexNode> aboveFilters,
+            List<RexNode> joinFilters, Join join, JoinRelType joinType) {
+        for (Iterator<RexNode> it = joinFilters.iterator(); it.hasNext();) {
+            final RexNode candidate = it.next();
+            if (predicate.apply(join, joinType, candidate)) {
+                // Accepted in the ON clause; leave it where it is.
+                continue;
+            }
+            aboveFilters.add(candidate);
+            it.remove();
+        }
+    }
+
+    /**
+     * Rule that pushes parts of the join condition to its inputs.
+     *
+     * <p>Matches any {@link Join}; no filter above the join is involved.
+     */
+    public static class FlinkJoinConditionPushRule extends FlinkFilterJoinRule {
+        /**
+         * Creates the rule with {@code smart} switched on.
+         *
+         * @param relBuilderFactory builder factory
+         * @param predicate decides whether a filter may stay in the ON clause
+         */
+        public FlinkJoinConditionPushRule(RelBuilderFactory relBuilderFactory,
+                Predicate predicate) {
+            super(
+                    RelOptRule.operand(Join.class, RelOptRule.any()),
+                    "FlinkFilterJoinRule:no-filter",
+                    true,
+                    relBuilderFactory,
+                    predicate);
+        }
+
+        @Deprecated // to be removed before 2.0
+        public FlinkJoinConditionPushRule(RelFactories.FilterFactory filterFactory,
+                RelFactories.ProjectFactory projectFactory, Predicate predicate) {
+            this(RelBuilder.proto(filterFactory, projectFactory), predicate);
+        }
+
+        /** Runs the shared push-down logic with no filter above the join. */
+        @Override public void onMatch(RelOptRuleCall call) {
+            final Join matchedJoin = call.rel(0);
+            perform(call, null, matchedJoin);
+        }
+    }
+
+    /**
+     * Rule that tries to push filter expressions into a join condition and
+     * into the inputs of the join.
+     *
+     * <p>Matches a {@link Filter} whose input is a {@link Join}.
+     */
+    public static class FlinkFilterIntoJoinRule extends FlinkFilterJoinRule {
+        /**
+         * Creates the rule.
+         *
+         * @param smart whether to try to strengthen the join type
+         * @param relBuilderFactory builder factory
+         * @param predicate decides whether a filter may stay in the ON clause
+         */
+        public FlinkFilterIntoJoinRule(boolean smart,
+                RelBuilderFactory relBuilderFactory, Predicate predicate) {
+            super(
+                    RelOptRule.operand(Filter.class,
+                            RelOptRule.operand(Join.class, RelOptRule.any())),
+                    "FlinkFilterJoinRule:filter",
+                    smart,
+                    relBuilderFactory,
+                    predicate);
+        }
+
+        @Deprecated // to be removed before 2.0
+        public FlinkFilterIntoJoinRule(boolean smart,
+                RelFactories.FilterFactory filterFactory,
+                RelFactories.ProjectFactory projectFactory,
+                Predicate predicate) {
+            this(smart, RelBuilder.proto(filterFactory, projectFactory), predicate);
+        }
+
+        /** Runs the shared push-down logic on the matched filter and join. */
+        @Override public void onMatch(RelOptRuleCall call) {
+            final Filter matchedFilter = call.rel(0);
+            final Join matchedJoin = call.rel(1);
+            perform(call, matchedFilter, matchedJoin);
+        }
+    }
+
+    /**
+     * Predicate that returns whether a filter is valid in the ON clause of a
+     * join for this particular kind of join. If not, the rule pushes it back
+     * to above the join.
+     *
+     * <p>It has a single abstract method and is used as a lambda target (see
+     * {@code TRUE_PREDICATE}), hence the {@code @FunctionalInterface} marker.
+     */
+    @FunctionalInterface
+    public interface Predicate {
+        /**
+         * Tests one conjunct of a join condition.
+         *
+         * @param join the join being rewritten
+         * @param joinType join type to assume; may differ from
+         *     {@code join.getJoinType()} after outer-join simplification
+         * @param exp the filter expression to test
+         * @return {@code true} if {@code exp} may remain in the ON clause
+         */
+        boolean apply(Join join, JoinRelType joinType, RexNode exp);
+    }
+}
+
+// End FlinkFilterJoinRule.java
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkJoinPushExpressionsRule.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkJoinPushExpressionsRule.java
new file mode 100644
index 0000000..7e6b960
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkJoinPushExpressionsRule.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.rules.logical;
+
+import org.apache.flink.table.plan.util.FlinkRelOptUtil;
+
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Join;
+import org.apache.calcite.rel.core.RelFactories;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.tools.RelBuilderFactory;
+
+/**
+ * This rule is copied from Calcite's {@link org.apache.calcite.rel.rules.JoinPushExpressionsRule}.
+ * Modification:
+ * - Supports SEMI/ANTI join using {@link org.apache.flink.table.plan.util.FlinkRelOptUtil#pushDownJoinConditions}
+ */
+
+/**
+ * Planner rule that pushes down expressions in "equal" join condition.
+ *
+ * <p>For example, given
+ * "emp JOIN dept ON emp.deptno + 1 = dept.deptno", adds a project above
+ * "emp" that computes the expression
+ * "emp.deptno + 1". The resulting join condition is a simple combination
+ * of AND, equals, and input fields, plus the remaining non-equal conditions.
+ */
+public class FlinkJoinPushExpressionsRule extends RelOptRule {
+
+ public static final FlinkJoinPushExpressionsRule INSTANCE =
+ new FlinkJoinPushExpressionsRule(Join.class, RelFactories.LOGICAL_BUILDER);
+
+ /** Creates a JoinPushExpressionsRule. */
+ public FlinkJoinPushExpressionsRule(Class<? extends Join> clazz,
+ RelBuilderFactory relBuilderFactory) {
+ super(operand(clazz, any()), relBuilderFactory, null);
+ }
+
+ @Deprecated // to be removed before 2.0
+ public FlinkJoinPushExpressionsRule(Class<? extends Join> clazz,
+ RelFactories.ProjectFactory projectFactory) {
+ this(clazz, RelBuilder.proto(projectFactory));
+ }
+
+ @Override
+ public void onMatch(RelOptRuleCall call) {
+ Join join = call.rel(0);
+
+ // Push expression in join condition into Project below Join.
+ RelNode newJoin = FlinkRelOptUtil.pushDownJoinConditions(join, call.builder());
+
+ // If the join is the same, we bail out
+ if (newJoin instanceof Join) {
+ final RexNode newCondition = ((Join) newJoin).getCondition();
+ if (join.getCondition().equals(newCondition)) {
+ return;
+ }
+ }
+
+ call.transformTo(newJoin);
+ }
+}
+
+// End FlinkJoinPushExpressionsRule.java
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkProjectJoinTransposeRule.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkProjectJoinTransposeRule.java
new file mode 100644
index 0000000..2a3bf9a
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkProjectJoinTransposeRule.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.rules.logical;
+
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Join;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.core.RelFactories;
+import org.apache.calcite.rel.rules.PushProjector;
+import org.apache.calcite.rel.type.RelDataTypeField;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.tools.RelBuilderFactory;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * This rule is copied from Calcite's {@link org.apache.calcite.rel.rules.ProjectJoinTransposeRule}.
+ * Modification:
+ * - Does not match SEMI/ANTI join now
+ */
+
+/**
+ * Planner rule that pushes a {@link org.apache.calcite.rel.core.Project}
+ * past a {@link org.apache.calcite.rel.core.Join}
+ * by splitting the projection into a projection on top of each child of
+ * the join.
+ *
+ * <p>Joins whose type does not project the right input are left untouched.
+ */
+public class FlinkProjectJoinTransposeRule extends RelOptRule {
+    public static final FlinkProjectJoinTransposeRule INSTANCE =
+        new FlinkProjectJoinTransposeRule(expr -> true,
+            RelFactories.LOGICAL_BUILDER);
+
+    //~ Instance fields --------------------------------------------------------
+
+    /**
+     * Condition for expressions that should be preserved in the projection.
+     */
+    private final PushProjector.ExprCondition preserveExprCondition;
+
+    //~ Constructors -----------------------------------------------------------
+
+    /**
+     * Creates a FlinkProjectJoinTransposeRule with an explicit condition.
+     *
+     * @param preserveExprCondition Condition for expressions that should be
+     *                              preserved in the projection
+     * @param relFactory            Builder factory handed to the rule base class
+     */
+    public FlinkProjectJoinTransposeRule(
+            PushProjector.ExprCondition preserveExprCondition,
+            RelBuilderFactory relFactory) {
+        super(
+            operand(Project.class,
+                operand(Join.class, any())),
+            relFactory, null);
+        this.preserveExprCondition = preserveExprCondition;
+    }
+
+    //~ Methods ----------------------------------------------------------------
+
+    /**
+     * Splits the matched project (operand 0) into projections on both inputs of the
+     * matched join (operand 1), and puts an adjusted project on top of the new join.
+     */
+    @Override
+    public void onMatch(RelOptRuleCall call) {
+        final Project origProj = call.rel(0);
+        final Join join = call.rel(1);
+
+        if (!join.getJoinType().projectsRight()) {
+            return; // TODO: support SEMI/ANTI join later
+        }
+        // locate all fields referenced in the projection and join condition;
+        // determine which inputs are referenced in the projection and
+        // join condition; if all fields are being referenced and there are no
+        // special expressions, no point in proceeding any further
+        final PushProjector pushProject =
+            new PushProjector(
+                origProj,
+                join.getCondition(),
+                join,
+                preserveExprCondition,
+                call.builder());
+        if (pushProject.locateAllRefs()) {
+            return;
+        }
+
+        // create left and right projections, projecting only those
+        // fields referenced on each side
+        final RelNode leftProjRel =
+            pushProject.createProjectRefsAndExprs(
+                join.getLeft(),
+                true,
+                false);
+        final RelNode rightProjRel =
+            pushProject.createProjectRefsAndExprs(
+                join.getRight(),
+                true,
+                true);
+
+        // convert the join condition to reference the projected columns
+        RexNode newJoinFilter = null;
+        final int[] adjustments = pushProject.getAdjustments();
+        if (join.getCondition() != null) {
+            // field list of the new join: system fields, then left, then right
+            final List<RelDataTypeField> projJoinFieldList = new ArrayList<>();
+            projJoinFieldList.addAll(
+                join.getSystemFieldList());
+            projJoinFieldList.addAll(
+                leftProjRel.getRowType().getFieldList());
+            projJoinFieldList.addAll(
+                rightProjRel.getRowType().getFieldList());
+            newJoinFilter =
+                pushProject.convertRefsAndExprs(
+                    join.getCondition(),
+                    projJoinFieldList,
+                    adjustments);
+        }
+
+        // create a new join with the projected children
+        final Join newJoinRel =
+            join.copy(
+                join.getTraitSet(),
+                newJoinFilter,
+                leftProjRel,
+                rightProjRel,
+                join.getJoinType(),
+                join.isSemiJoinDone());
+
+        // put the original project on top of the join, converting it to
+        // reference the modified projection list
+        final RelNode topProject =
+            pushProject.createNewProject(newJoinRel, adjustments);
+
+        call.transformTo(topProject);
+    }
+}
+
+// End FlinkProjectJoinTransposeRule.java
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/SubQueryDecorrelator.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/SubQueryDecorrelator.java
new file mode 100644
index 0000000..a404e3f
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/SubQueryDecorrelator.java
@@ -0,0 +1,1445 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.rules.logical;
+
+import org.apache.flink.table.calcite.FlinkRelBuilder;
+import org.apache.flink.table.plan.util.FlinkRelOptUtil;
+import org.apache.flink.table.plan.util.FlinkRexUtil;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.plan.hep.HepRelVertex;
+import org.apache.calcite.rel.RelCollation;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.RelShuttleImpl;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.CorrelationId;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.core.SetOp;
+import org.apache.calcite.rel.core.Sort;
+import org.apache.calcite.rel.core.Values;
+import org.apache.calcite.rel.logical.LogicalAggregate;
+import org.apache.calcite.rel.logical.LogicalCorrelate;
+import org.apache.calcite.rel.logical.LogicalFilter;
+import org.apache.calcite.rel.logical.LogicalIntersect;
+import org.apache.calcite.rel.logical.LogicalJoin;
+import org.apache.calcite.rel.logical.LogicalMinus;
+import org.apache.calcite.rel.logical.LogicalProject;
+import org.apache.calcite.rel.logical.LogicalSort;
+import org.apache.calcite.rel.logical.LogicalUnion;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeField;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexCorrelVariable;
+import org.apache.calcite.rex.RexFieldAccess;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexLiteral;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexOver;
+import org.apache.calcite.rex.RexShuttle;
+import org.apache.calcite.rex.RexSubQuery;
+import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.rex.RexVisitor;
+import org.apache.calcite.rex.RexVisitorImpl;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.util.Bug;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.Litmus;
+import org.apache.calcite.util.Pair;
+import org.apache.calcite.util.ReflectUtil;
+import org.apache.calcite.util.ReflectiveVisitor;
+import org.apache.calcite.util.Util;
+import org.apache.calcite.util.mapping.Mappings;
+
+import javax.annotation.Nonnull;
+
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.NavigableMap;
+import java.util.Objects;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.TreeMap;
+import java.util.TreeSet;
+
+/**
+ * SubQueryDecorrelator finds all correlated expressions in a SubQuery,
+ * and gets an equivalent non-correlated relational expression tree and correlation conditions.
+ *
+ * <p>The basic idea of SubQueryDecorrelator is from {@link org.apache.calcite.sql2rel.RelDecorrelator},
+ * however there are differences between them:
+ * 1. This class works with {@link RexSubQuery}, while RelDecorrelator works with {@link LogicalCorrelate}.
+ * 2. This class will get an equivalent non-correlated expression tree and correlation conditions,
+ * while RelDecorrelator will replace all correlated expressions with non-correlated expressions that are produced
+ * from joining the RelNode.
+ * 3. This class supports both equi and non-equi correlation conditions,
+ * while RelDecorrelator only supports equi correlation conditions.
+ */
+public class SubQueryDecorrelator extends RelShuttleImpl {
+ // rewrites the rel tree of each visited SubQuery and returns the resulting Frame
+ private final SubQueryRelDecorrelator decorrelator;
+ // used to build the extra projections added around IN / EXISTS subquery rels
+ private final RelBuilder relBuilder;
+
+ // map a SubQuery to an equivalent RelNode and correlation-condition pair
+ private final Map<RexSubQuery, Pair<RelNode, RexNode>> subQueryMap = new HashMap<>();
+
+ /**
+ * Creates the shuttle; instances are only built by {@link #decorrelateQuery(RelNode)}.
+ *
+ * @param decorrelator rewriter invoked on the rel of each visited SubQuery
+ * @param relBuilder builder used for the projections added around subquery rels
+ */
+ private SubQueryDecorrelator(SubQueryRelDecorrelator decorrelator, RelBuilder relBuilder) {
+ this.decorrelator = decorrelator;
+ this.relBuilder = relBuilder;
+ }
+
+ /**
+  * Decorrelates the subqueries found under the given node.
+  *
+  * <p>This is the main entry point to {@code SubQueryDecorrelator}.
+  *
+  * @param rootRel The node which has SubQuery.
+  * @return {@code null} if a nested correlation scope or an unsupported correlation
+  *     condition was detected; {@code Result.EMPTY} if there is no correlation at all;
+  *     otherwise a result wrapping the SubQuery to (RelNode, correlation condition) map.
+  */
+ public static Result decorrelateQuery(RelNode rootRel) {
+     final int cnfNodeLimit = FlinkRelOptUtil.getMaxCnfNodeCount(rootRel);
+
+     final CorelMapBuilder corelMapBuilder = new CorelMapBuilder(cnfNodeLimit);
+     final CorelMap correlations = corelMapBuilder.build(rootRel);
+
+     // The flags are filled in while the map is built, so check them afterwards.
+     final boolean unsupported =
+         corelMapBuilder.hasNestedCorScope || corelMapBuilder.hasUnsupportedCorCondition;
+     if (unsupported) {
+         return null;
+     }
+     if (!correlations.hasCorrelation()) {
+         return Result.EMPTY;
+     }
+
+     final RelOptCluster cluster = rootRel.getCluster();
+     final RelBuilder builder =
+         new FlinkRelBuilder(cluster.getPlanner().getContext(), cluster, null);
+     final SubQueryRelDecorrelator relDecorrelator = new SubQueryRelDecorrelator(
+         correlations, builder, cluster.getRexBuilder(), cnfNodeLimit);
+
+     final SubQueryDecorrelator shuttle = new SubQueryDecorrelator(relDecorrelator, builder);
+     rootRel.accept(shuttle);
+
+     return new Result(shuttle.subQueryMap);
+ }
+
+ @Override
+ protected RelNode visitChild(RelNode parent, int i, RelNode input) {
+ return super.visitChild(parent, i, stripHep(input));
+ }
+
+ /**
+  * Collects the decorrelation result of every SubQuery in the filter condition,
+  * then continues with the default traversal of the filter.
+  */
+ @Override
+ public RelNode visit(final LogicalFilter filter) {
+     try {
+         stack.push(filter);
+         final RexNode condition = filter.getCondition();
+         condition.accept(handleSubQuery(filter));
+     } finally {
+         stack.pop();
+     }
+     return super.visit(filter);
+ }
+
+ /**
+  * Creates a visitor that decorrelates each visited SubQuery and records the rewritten
+  * rel with its correlation condition in {@code subQueryMap}. A SubQuery is skipped
+  * when no frame is produced for it or the frame carries no correlation condition.
+  *
+  * @param rel the node owning the condition; its row type is the left side of the
+  *     recorded condition and its variables set is passed to the rewriting shuttle
+  */
+ private RexVisitorImpl<Void> handleSubQuery(final RelNode rel) {
+     return new RexVisitorImpl<Void>(true) {
+
+         @Override
+         public Void visitSubQuery(RexSubQuery subQuery) {
+             final SqlKind kind = subQuery.getKind();
+             final RelNode subQueryRel = kind == SqlKind.IN
+                 ? addProjectionForIn(subQuery.rel)
+                 : subQuery.rel;
+
+             final Frame frame = decorrelator.getInvoke(subQueryRel);
+             if (frame == null || frame.c == null) {
+                 return null;
+             }
+
+             final Frame target = kind == SqlKind.EXISTS
+                 ? addProjectionForExists(frame)
+                 : frame;
+
+             final DecorrelateRexShuttle shuttle = new DecorrelateRexShuttle(
+                 rel.getRowType(),
+                 target.r.getRowType(),
+                 rel.getVariablesSet());
+             final RexNode joinCondition = target.c.accept(shuttle);
+
+             subQueryMap.put(subQuery, new Pair<>(target.r, joinCondition));
+             return null;
+         }
+     };
+ }
+
+ /**
+  * Adds Projection to adjust the field index for join condition.
+  *
+  * <p>e.g. SQL: SELECT * FROM l WHERE b IN (SELECT COUNT(*) FROM r WHERE l.c = r.f)
+  * the rel in SubQuery is `LogicalAggregate(group=[{}], EXPR$1=[COUNT()])`.
+  * After decorrelated, it was changed to `LogicalAggregate(group=[{0}], EXPR$0=[COUNT()])`,
+  * and the output index of `COUNT()` was changed from 0 to 1.
+  * So, add a project (`LogicalProject(EXPR$0=[$1], f=[$0])`) to adjust output fields order.
+  *
+  * <p>A rel that already is a {@link LogicalProject} is returned unchanged; any other
+  * rel gets a projection of all its fields, in order and under their current names.
+  */
+ private RelNode addProjectionForIn(RelNode relNode) {
+     if (relNode instanceof LogicalProject) {
+         return relNode;
+     }
+
+     final RelDataType rowType = relNode.getRowType();
+     final int fieldCount = rowType.getFieldCount();
+     final List<RexNode> identityProjects = new ArrayList<>(fieldCount);
+     for (int field = 0; field < fieldCount; field++) {
+         identityProjects.add(RexInputRef.of(field, rowType));
+     }
+
+     relBuilder.clear();
+     return relBuilder
+         .push(relNode)
+         .project(identityProjects, rowType.getFieldNames(), true)
+         .build();
+ }
+
+ /**
+  * Adds Projection to choose the fields used by join condition.
+  *
+  * <p>The frame is returned as is when the number of correlated input indices equals
+  * the field count of its rel. Otherwise only those fields are projected, in ascending
+  * index order, and the correlation condition is re-indexed against the projection.
+  */
+ private Frame addProjectionForExists(Frame frame) {
+     final RelNode input = frame.r;
+     final RelDataType inputRowType = input.getRowType();
+     final List<Integer> sortedCorIndices = new ArrayList<>(frame.getCorInputRefIndices());
+     if (sortedCorIndices.size() == inputRowType.getFieldCount()) {
+         // no need projection
+         return frame;
+     }
+     Collections.sort(sortedCorIndices);
+
+     final List<RexNode> projects = new ArrayList<>();
+     final Map<Integer, Integer> inputToOutput = new HashMap<>();
+     for (int outputPos = 0; outputPos < sortedCorIndices.size(); outputPos++) {
+         final int inputIndex = sortedCorIndices.get(outputPos);
+         projects.add(RexInputRef.of(inputIndex, inputRowType));
+         inputToOutput.put(inputIndex, outputPos);
+     }
+
+     relBuilder.clear();
+     relBuilder.push(input);
+     relBuilder.project(projects);
+     final RelNode newProject = relBuilder.build();
+
+     final RexNode newCondition =
+         adjustInputRefs(frame.c, inputToOutput, newProject.getRowType());
+
+     // There is no old RelNode corresponding to newProject, so oldToNewOutputs is empty.
+     return new Frame(input, newProject, newCondition, new HashMap<>());
+ }
+
+ /** Returns the current rel of a {@link HepRelVertex}, or the given rel itself otherwise. */
+ private static RelNode stripHep(RelNode rel) {
+     return rel instanceof HepRelVertex
+         ? ((HepRelVertex) rel).getCurrentRel()
+         : rel;
+ }
+
+ /**
+ * Splits a condition into correlated, non-correlated and unsupported conjuncts.
+ *
+ * <p>The condition is converted to conjunctive normal form and split into its AND
+ * operands; each operand is appended to exactly one of the three output lists.
+ *
+ * @param variableSet correlation ids whose references count as supported correlation
+ * @param condition the condition to analyze
+ * @param rexBuilder builder used for the CNF conversion
+ * @param maxCnfNodeCount node count limit passed to the CNF conversion
+ * @param corConditions output: conjuncts classified as supported correlation conditions
+ * @param nonCorConditions output: conjuncts that are not correlation conditions
+ * @param unsupportedCorConditions output: conjuncts classified as unsupported
+ */
+ private static void analyzeCorConditions(
+ final Set<CorrelationId> variableSet,
+ final RexNode condition,
+ final RexBuilder rexBuilder,
+ final int maxCnfNodeCount,
+ final List<RexNode> corConditions,
+ final List<RexNode> nonCorConditions,
+ final List<RexNode> unsupportedCorConditions) {
+ // converts the expanded expression to conjunctive normal form,
+ // like "(a AND b) OR c" will be converted to "(a OR c) AND (b OR c)"
+ final RexNode cnf = FlinkRexUtil.toCnf(rexBuilder, maxCnfNodeCount, condition);
+ // converts the cnf condition to a list of AND conditions
+ final List<RexNode> conjunctions = RelOptUtil.conjunctions(cnf);
+ // `true` for RexNode is supported correlation condition,
+ // `false` for RexNode is unsupported correlation condition,
+ // `null` for RexNode is not a correlation condition.
+ final RexVisitorImpl<Boolean> visitor = new RexVisitorImpl<Boolean>(true) {
+
+ @Override
+ public Boolean visitFieldAccess(RexFieldAccess fieldAccess) {
+ // a field access on a correlation variable is classified like the variable itself
+ final RexNode ref = fieldAccess.getReferenceExpr();
+ if (ref instanceof RexCorrelVariable) {
+ return visitCorrelVariable((RexCorrelVariable) ref);
+ } else {
+ return super.visitFieldAccess(fieldAccess);
+ }
+ }
+
+ @Override
+ public Boolean visitCorrelVariable(RexCorrelVariable correlVariable) {
+ // supported only if the variable belongs to the given variable set
+ return variableSet.contains(correlVariable.id);
+ }
+
+ @Override
+ public Boolean visitSubQuery(RexSubQuery subQuery) {
+ final List<Boolean> result = new ArrayList<>();
+ for (RexNode operand : subQuery.operands) {
+ result.add(operand.accept(this));
+ }
+ // we do not support nested correlation variables in SubQuery, such as:
+ // select * from t1 where exists(select * from t2 where t1.a = t2.c and t1.b in (select t3.d from t3)
+ if (result.contains(true) || result.contains(false)) {
+ return false;
+ } else {
+ return null;
+ }
+ }
+
+ @Override
+ public Boolean visitCall(RexCall call) {
+ final List<Boolean> result = new ArrayList<>();
+ for (RexNode operand : call.operands) {
+ result.add(operand.accept(this));
+ }
+ // an unsupported operand makes the whole call unsupported
+ if (result.contains(false)) {
+ return false;
+ } else if (result.contains(true)) {
+ // TODO supports correlation variable with OR
+ // return call.op.getKind() != SqlKind.OR || !result.contains(null);
+ return call.op.getKind() != SqlKind.OR;
+ } else {
+ return null;
+ }
+ }
+ };
+
+ // classify each conjunct by the tri-state result of the visitor
+ for (RexNode c : conjunctions) {
+ Boolean r = c.accept(visitor);
+ if (r == null) {
+ nonCorConditions.add(c);
+ } else if (r) {
+ corConditions.add(c);
+ } else {
+ unsupportedCorConditions.add(c);
+ }
+ }
+ }
+
+ /**
+ * Adjust the condition's field indices according to mapOldToNewIndex.
+ *
+ * @param c The condition to be adjusted.
+ * @param mapOldToNewIndex A map containing the mapping the old field indices to new field indices.
+ * @param rowType The row type of the new output.
+ * @return Return new condition with new field indices.
+ */
+ private static RexNode adjustInputRefs(
+ final RexNode c,
+ final Map<Integer, Integer> mapOldToNewIndex,
+ final RelDataType rowType) {
+ return c.accept(new RexShuttle() {
+ @Override
+ public RexNode visitInputRef(RexInputRef inputRef) {
+ assert mapOldToNewIndex.containsKey(inputRef.getIndex());
+ int newIndex = mapOldToNewIndex.get(inputRef.getIndex());
+ final RexInputRef ref = RexInputRef.of(newIndex, rowType);
+ if (ref.getIndex() == inputRef.getIndex() && ref.getType() == inputRef.getType()) {
+ return inputRef; // re-use old object, to prevent needless expr cloning
+ } else {
+ return ref;
+ }
+ }
+ });
+ }
+
+ private static class DecorrelateRexShuttle extends RexShuttle {
+ private final RelDataType leftRowType;
+ private final RelDataType rightRowType;
+ private final Set<CorrelationId> variableSet;
+
+ private DecorrelateRexShuttle(
+ RelDataType leftRowType,
+ RelDataType rightRowType,
+ Set<CorrelationId> variableSet) {
+ this.leftRowType = leftRowType;
+ this.rightRowType = rightRowType;
+ this.variableSet = variableSet;
+ }
+
+ @Override
+ public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
+ final RexNode ref = fieldAccess.getReferenceExpr();
+ if (ref instanceof RexCorrelVariable) {
+ final RexCorrelVariable var = (RexCorrelVariable) ref;
+ assert variableSet.contains(var.id);
+ final RelDataTypeField field = fieldAccess.getField();
+ return new RexInputRef(field.getIndex(), field.getType());
+ } else {
+ return super.visitFieldAccess(fieldAccess);
+ }
+ }
+
+ @Override
+ public RexNode visitInputRef(RexInputRef inputRef) {
+ assert inputRef.getIndex() < rightRowType.getFieldCount();
+ int newIndex = inputRef.getIndex() + leftRowType.getFieldCount();
+ return new RexInputRef(newIndex, inputRef.getType());
+ }
+ }
+
+ /**
+ * Pull out all correlation conditions from a given subquery to top level,
+ * and rebuild the subquery rel tree without correlation conditions.
+ *
+ * <p>`public` is for reflection.
+ * We use ReflectiveVisitor instead of RelShuttle because RelShuttle returns RelNode.
+ */
+ public static class SubQueryRelDecorrelator implements ReflectiveVisitor {
+ // map built during translation
+ private final CorelMap cm;
+ private final RelBuilder relBuilder;
+ private final RexBuilder rexBuilder;
+ // dispatcher over the methods named "decorrelateRel" of this object, returning a Frame
+ private final ReflectUtil.MethodDispatcher<Frame> dispatcher =
+ ReflectUtil.createMethodDispatcher(Frame.class, this, "decorrelateRel", RelNode.class);
+ // node count limit passed to the CNF conversion in analyzeCorConditions
+ private final int maxCnfNodeCount;
+
+ /**
+ * Creates a decorrelator over the given correlation map.
+ *
+ * @param cm correlation map consulted by the rewrite methods
+ * @param relBuilder builder for relational expressions
+ * @param rexBuilder builder for row expressions
+ * @param maxCnfNodeCount node count limit used when conditions are converted to CNF
+ */
+ SubQueryRelDecorrelator(CorelMap cm, RelBuilder relBuilder, RexBuilder rexBuilder, int maxCnfNodeCount) {
+ this.cm = cm;
+ this.relBuilder = relBuilder;
+ this.rexBuilder = rexBuilder;
+ this.maxCnfNodeCount = maxCnfNodeCount;
+ }
+
+ /**
+ * Invokes the reflective dispatcher on the given rel.
+ *
+ * @param r the rel to rewrite
+ * @return the frame returned by the dispatched method; callers treat {@code null}
+ * as "not rewritten"
+ */
+ Frame getInvoke(RelNode r) {
+ return dispatcher.invoke(r);
+ }
+
+ /**
+ * Rewrite LogicalProject.
+ *
+ * <p>Rewrite logic:
+ * Pass along any correlated variables coming from the input.
+ *
+ * @param rel the project rel to rewrite
+ * @return the frame for the new project, or {@code null} if the input was not rewritten
+ */
+ public Frame decorrelateRel(LogicalProject rel) {
+ final RelNode oldInput = rel.getInput();
+ Frame frame = getInvoke(oldInput);
+ if (frame == null) {
+ // If input has not been rewritten, do not rewrite this rel.
+ return null;
+ }
+
+ final List<RexNode> oldProjects = rel.getProjects();
+ final List<RelDataTypeField> relOutput = rel.getRowType().getFieldList();
+ final RelNode newInput = frame.r;
+
+ // Project projects the original expressions,
+ // plus any correlated variables the input wants to pass along.
+ final List<Pair<RexNode, String>> projects = new ArrayList<>();
+
+ // If this Project has correlated reference, produce the correlated variables in the new output.
+ // TODO Currently, correlation in projection is not supported.
+ assert !cm.mapRefRelToCorRef.containsKey(rel);
+
+ // input field index -> output position, for inputs exposed by a plain input reference
+ final Map<Integer, Integer> mapInputToOutput = new HashMap<>();
+ // the original projects keep their positions, so this is an identity mapping
+ final Map<Integer, Integer> mapOldToNewOutputs = new HashMap<>();
+ // Project projects the original expressions
+ int newPos;
+ for (newPos = 0; newPos < oldProjects.size(); newPos++) {
+ RexNode project = adjustInputRefs(
+ oldProjects.get(newPos), frame.oldToNewOutputs, newInput.getRowType());
+ projects.add(newPos, Pair.of(project, relOutput.get(newPos).getName()));
+ mapOldToNewOutputs.put(newPos, newPos);
+ if (project instanceof RexInputRef) {
+ mapInputToOutput.put(((RexInputRef) project).getIndex(), newPos);
+ }
+ }
+
+ if (frame.c != null) {
+ // Project any correlated variables the input wants to pass along.
+ final ImmutableBitSet corInputIndices = RelOptUtil.InputFinder.bits(frame.c);
+ final RelDataType inputRowType = newInput.getRowType();
+ for (int inputIndex : corInputIndices.toList()) {
+ // append only the inputs that no existing project already exposes
+ if (!mapInputToOutput.containsKey(inputIndex)) {
+ projects.add(newPos, Pair.of(
+ RexInputRef.of(inputIndex, inputRowType),
+ inputRowType.getFieldNames().get(inputIndex)));
+ mapInputToOutput.put(inputIndex, newPos);
+ newPos++;
+ }
+ }
+ }
+ RelNode newProject = RelOptUtil.createProject(newInput, projects, false);
+
+ // re-index the input's correlation condition against the output of the new project
+ final RexNode newCorCondition;
+ if (frame.c != null) {
+ newCorCondition = adjustInputRefs(frame.c, mapInputToOutput, newProject.getRowType());
+ } else {
+ newCorCondition = null;
+ }
+
+ return new Frame(rel, newProject, newCorCondition, mapOldToNewOutputs);
+ }
+
+ /**
+ * Rewrite LogicalFilter.
+ *
+ * <p>Rewrite logic:
+ * 1. If a Filter references a correlated field in its filter condition,
+ * rewrite the Filter to reference only non-correlated fields,
+ * and the conditions referencing correlated fields will be pushed to its output.
+ * 2. If Filter does not reference correlated variables,
+ * simply rewrite the filter condition using new input.
+ *
+ * @param rel the filter rel to rewrite
+ * @return the frame for the new filter, or {@code null} if the input was not rewritten
+ */
+ public Frame decorrelateRel(LogicalFilter rel) {
+ final RelNode oldInput = rel.getInput();
+ Frame frame = getInvoke(oldInput);
+ if (frame == null) {
+ // If input has not been rewritten, do not rewrite this rel.
+ return null;
+ }
+
+ // Conditions that reference correlated fields and are supported
+ final List<RexNode> corConditions = new ArrayList<>();
+ // Conditions do not reference any correlated fields
+ final List<RexNode> nonCorConditions = new ArrayList<>();
+ // Conditions reference correlated fields, but not supported now
+ final List<RexNode> unsupportedCorConditions = new ArrayList<>();
+
+ analyzeCorConditions(
+ cm.mapSubQueryNodeToCorSet.get(rel),
+ rel.getCondition(),
+ rexBuilder,
+ maxCnfNodeCount,
+ corConditions,
+ nonCorConditions,
+ unsupportedCorConditions);
+ assert unsupportedCorConditions.isEmpty();
+
+ // the non-correlated conjuncts stay in the rewritten filter
+ final RexNode remainingCondition = RexUtil.composeConjunction(rexBuilder, nonCorConditions, false);
+
+ // Using LogicalFilter.create instead of RelBuilder.filter to create Filter
+ // because RelBuilder.filter method does not have VariablesSet arg.
+ final LogicalFilter newFilter = LogicalFilter.create(
+ frame.r,
+ remainingCondition,
+ com.google.common.collect.ImmutableSet.copyOf(rel.getVariablesSet()));
+
+ // Adds input's correlation condition
+ if (frame.c != null) {
+ corConditions.add(frame.c);
+ }
+
+ // the correlated conjuncts are carried in the frame instead of the filter
+ final RexNode corCondition = RexUtil.composeConjunction(rexBuilder, corConditions, true);
+ // Filter does not change the input ordering.
+ // All corVars produced by filter will have the same output positions in the input rel.
+ return new Frame(rel, newFilter, corCondition, frame.oldToNewOutputs);
+ }
+
+ /**
+ * Rewrites a {@link LogicalAggregate}.
+ *
+ * <p>Rewrite logic:
+ * 1. Permute the group by keys to the front.
+ * 2. If the input of an aggregate produces correlated variables, add them to the group list.
+ * 3. Change aggCalls to reference the new project.
+ *
+ * @param rel Aggregate to rewrite
+ */
+ public Frame decorrelateRel(LogicalAggregate rel) {
+ if (rel.getGroupType() != Aggregate.Group.SIMPLE) {
+ throw new AssertionError(Bug.CALCITE_461_FIXED);
+ }
+
+ // Aggregate itself should not reference corVars.
+ assert !cm.mapRefRelToCorRef.containsKey(rel);
+
+ final RelNode oldInput = rel.getInput();
+ final Frame frame = getInvoke(oldInput);
+ if (frame == null) {
+ // If input has not been rewritten, do not rewrite this rel.
+ return null;
+ }
+
+ final RelNode newInput = frame.r;
+ // map from newInput
+ final Map<Integer, Integer> mapNewInputToProjOutputs = new HashMap<>();
+ final int oldGroupKeyCount = rel.getGroupSet().cardinality();
+
+ // Project projects the original expressions,
+ // plus any correlated variables the input wants to pass along.
+ final List<Pair<RexNode, String>> projects = new ArrayList<>();
+ final List<RelDataTypeField> newInputOutput = newInput.getRowType().getFieldList();
+
+ // oldInput has the original group by keys in the front.
+ final NavigableMap<Integer, RexLiteral> omittedConstants = new TreeMap<>();
+ int newPos = 0;
+ for (int i = 0; i < oldGroupKeyCount; i++) {
+ final RexLiteral constant = projectedLiteral(newInput, i);
+ if (constant != null) {
+ // Exclude constants. Aggregate({true}) occurs because Aggregate({})
+ // would generate 1 row even when applied to an empty table.
+ omittedConstants.put(i, constant);
+ continue;
+ }
+
+ int newInputPos = frame.oldToNewOutputs.get(i);
+ projects.add(newPos, RexInputRef.of2(newInputPos, newInputOutput));
+ mapNewInputToProjOutputs.put(newInputPos, newPos);
+ newPos++;
+ }
+
+ if (frame.c != null) {
+ // If input produces correlated variables, move them to the front,
+ // right after any existing GROUP BY fields.
+
+ // Now add the corVars from the input, starting from position oldGroupKeyCount.
+ for (Integer index : frame.getCorInputRefIndices()) {
+ if (!mapNewInputToProjOutputs.containsKey(index)) {
+ projects.add(newPos, RexInputRef.of2(index, newInputOutput));
+ mapNewInputToProjOutputs.put(index, newPos);
+ newPos++;
+ }
+ }
+ }
+
+ // add the remaining fields
+ final int newGroupKeyCount = newPos;
+ for (int i = 0; i < newInputOutput.size(); i++) {
+ if (!mapNewInputToProjOutputs.containsKey(i)) {
+ projects.add(newPos, RexInputRef.of2(i, newInputOutput));
+ mapNewInputToProjOutputs.put(i, newPos);
+ newPos++;
+ }
+ }
+
+ assert newPos == newInputOutput.size();
+
+ // This Project will be what the old input maps to,
+ // replacing any previous mapping from old input).
+ final RelNode newProject = RelOptUtil.createProject(newInput, projects, false);
+
+ final RexNode newCondition;
+ if (frame.c != null) {
+ newCondition = adjustInputRefs(frame.c, mapNewInputToProjOutputs, newProject.getRowType());
+ } else {
+ newCondition = null;
+ }
+
+ // update mappings:
+ // oldInput ----> newInput
+ //
+ // newProject
+ // |
+ // oldInput ----> newInput
+ //
+ // is transformed to
+ //
+ // oldInput ----> newProject
+ // |
+ // newInput
+
+ final Map<Integer, Integer> combinedMap = new HashMap<>();
+ final Map<Integer, Integer> oldToNewOutputs = new HashMap<>();
+ final List<Integer> originalGrouping = rel.getGroupSet().toList();
+ for (Integer oldInputPos : frame.oldToNewOutputs.keySet()) {
+ final Integer newIndex = mapNewInputToProjOutputs.get(frame.oldToNewOutputs.get(oldInputPos));
+ combinedMap.put(oldInputPos, newIndex);
+ // mapping grouping fields
+ if (originalGrouping.contains(oldInputPos)) {
+ oldToNewOutputs.put(oldInputPos, newIndex);
+ }
+ }
+
+ // now it's time to rewrite the Aggregate
+ final ImmutableBitSet newGroupSet = ImmutableBitSet.range(newGroupKeyCount);
+ final List<AggregateCall> newAggCalls = new ArrayList<>();
+ final List<AggregateCall> oldAggCalls = rel.getAggCallList();
+
+ for (AggregateCall oldAggCall : oldAggCalls) {
+ final List<Integer> oldAggArgs = oldAggCall.getArgList();
+ final List<Integer> aggArgs = new ArrayList<>();
+
+ // Adjust the Aggregate argument positions.
+ // Note Aggregate does not change input ordering, so the input
+ // output position mapping can be used to derive the new positions
+ // for the argument.
+ for (int oldPos : oldAggArgs) {
+ aggArgs.add(combinedMap.get(oldPos));
+ }
+ final int filterArg = oldAggCall.filterArg < 0
+ ? oldAggCall.filterArg
+ : combinedMap.get(oldAggCall.filterArg);
+
+ newAggCalls.add(
+ oldAggCall.adaptTo(
+ newProject, aggArgs, filterArg, oldGroupKeyCount, newGroupKeyCount));
+ }
+
+ relBuilder.push(LogicalAggregate.create(newProject, false, newGroupSet, null, newAggCalls));
+
+ if (!omittedConstants.isEmpty()) {
+ final List<RexNode> postProjects = new ArrayList<>(relBuilder.fields());
+ for (Map.Entry<Integer, RexLiteral> entry : omittedConstants.entrySet()) {
+ postProjects.add(mapNewInputToProjOutputs.get(entry.getKey()), entry.getValue());
+ }
+ relBuilder.project(postProjects);
+ }
+
+ // mapping aggCall output fields
+ for (int i = 0; i < oldAggCalls.size(); ++i) {
+ oldToNewOutputs.put(oldGroupKeyCount + i, newGroupKeyCount + omittedConstants.size() + i);
+ }
+
+ // Aggregate does not change input ordering so corVars will be
+ // located at the same position as the input newProject.
+ return new Frame(rel, relBuilder.build(), newCondition, oldToNewOutputs);
+ }
+
+ /**
+ * Rewrite LogicalJoin.
+ *
+ * <p>Rewrite logic:
+ * 1. rewrite join condition.
+ * 2. map output positions and produce corVars if any.
+ *
+ * <p>Both inputs are decorrelated first. The join condition is then re-indexed against the
+ * rewritten inputs, and the correlation conditions carried by the two input frames are combined
+ * (the right one shifted past the new left fields) into the condition of the returned frame.
+ *
+ * @param rel Join
+ * @return the frame of the rewritten join, or {@code null} if either input was not rewritten
+ */
+ public Frame decorrelateRel(LogicalJoin rel) {
+ final RelNode oldLeft = rel.getInput(0);
+ final RelNode oldRight = rel.getInput(1);
+
+ final Frame leftFrame = getInvoke(oldLeft);
+ final Frame rightFrame = getInvoke(oldRight);
+
+ if (leftFrame == null || rightFrame == null) {
+ // If any input has not been rewritten, do not rewrite this rel.
+ return null;
+ }
+
+ switch (rel.getJoinType()) { // sanity checks: some sides of an outer join must not carry a correlation condition
+ case LEFT:
+ assert rightFrame.c == null; // LEFT join: right input must be free of correlation conditions
+ break;
+ case RIGHT:
+ assert leftFrame.c == null; // RIGHT join: left input must be free of correlation conditions
+ break;
+ case FULL:
+ assert leftFrame.c == null && rightFrame.c == null; // FULL join: neither input may carry one
+ break;
+ default:
+ break;
+ }
+
+ final int oldLeftFieldCount = oldLeft.getRowType().getFieldCount();
+ final int newLeftFieldCount = leftFrame.r.getRowType().getFieldCount();
+ final int oldRightFieldCount = oldRight.getRowType().getFieldCount();
+ assert rel.getRowType().getFieldCount() == oldLeftFieldCount + oldRightFieldCount;
+
+ final RexNode newJoinCondition = adjustJoinCondition(
+ rel.getCondition(),
+ oldLeftFieldCount,
+ newLeftFieldCount,
+ leftFrame.oldToNewOutputs,
+ rightFrame.oldToNewOutputs);
+
+ final RelNode newJoin = LogicalJoin.create(
+ leftFrame.r, rightFrame.r, newJoinCondition, rel.getVariablesSet(), rel.getJoinType());
+
+ // Create the mapping between the output of the old join rel and the new join rel
+ final Map<Integer, Integer> mapOldToNewOutputs = new HashMap<>();
+ // Left input positions are taken over from the left frame as they are.
+ mapOldToNewOutputs.putAll(leftFrame.oldToNewOutputs);
+
+ // Right input positions are shifted by newLeftFieldCount.
+ for (int i = 0; i < oldRightFieldCount; i++) {
+ mapOldToNewOutputs.put(i + oldLeftFieldCount, rightFrame.oldToNewOutputs.get(i) + newLeftFieldCount);
+ }
+
+ final List<RexNode> corConditions = new ArrayList<>();
+ if (leftFrame.c != null) {
+ corConditions.add(leftFrame.c); // left condition already refers to positions valid in the new join output
+ }
+ if (rightFrame.c != null) {
+ // Right input positions are shifted by newLeftFieldCount.
+ final Map<Integer, Integer> rightMapOldToNewOutputs = new HashMap<>();
+ for (int index : rightFrame.getCorInputRefIndices()) {
+ rightMapOldToNewOutputs.put(index, index + newLeftFieldCount);
+ }
+ final RexNode newRightCondition = adjustInputRefs(
+ rightFrame.c, rightMapOldToNewOutputs, newJoin.getRowType());
+ corConditions.add(newRightCondition);
+ }
+
+ final RexNode newCondition = RexUtil.composeConjunction(rexBuilder, corConditions, true); // presumably null when no input carried a condition (third argument looks like Calcite's nullOnEmpty flag; confirm)
+ return new Frame(rel, newJoin, newCondition, mapOldToNewOutputs);
+ }
+
+ private RexNode adjustJoinCondition(
+ final RexNode joinCondition,
+ final int oldLeftFieldCount,
+ final int newLeftFieldCount,
+ final Map<Integer, Integer> leftOldToNewOutputs,
+ final Map<Integer, Integer> rightOldToNewOutputs) {
+ return joinCondition.accept(new RexShuttle() {
+ @Override
+ public RexNode visitInputRef(RexInputRef inputRef) {
+ int oldIndex = inputRef.getIndex();
+ final int newIndex;
+ if (oldIndex < oldLeftFieldCount) {
+ // field from left
+ assert leftOldToNewOutputs.containsKey(oldIndex);
+ newIndex = leftOldToNewOutputs.get(oldIndex);
+ } else {
+ // field from right
+ oldIndex = oldIndex - oldLeftFieldCount;
+ assert rightOldToNewOutputs.containsKey(oldIndex);
+ newIndex = rightOldToNewOutputs.get(oldIndex) + newLeftFieldCount;
+ }
+ return new RexInputRef(newIndex, inputRef.getType());
+ }
+ });
+ }
+
+ /**
+ * Rewrite Sort.
+ *
+ * <p>A Sort only refers to its input through the field positions of its collation, so the
+ * rewrite consists of translating those positions to the rewritten input.
+ *
+ * @param rel Sort to be rewritten
+ * @return the frame of the rewritten Sort, or {@code null} if its input was not rewritten
+ */
+ public Frame decorrelateRel(Sort rel) {
+ // A Sort itself must not be registered as referencing corVars.
+ assert !cm.mapRefRelToCorRef.containsKey(rel);
+
+ final RelNode originalInput = rel.getInput();
+ final Frame inputFrame = getInvoke(originalInput);
+ if (inputFrame == null) {
+ // The input was not rewritten, so this rel is left alone as well.
+ return null;
+ }
+ final RelNode rewrittenInput = inputFrame.r;
+
+ // Translate the collation fields to the output positions of the rewritten input.
+ final Mappings.TargetMapping oldToNewFields = Mappings.target(
+ inputFrame.oldToNewOutputs,
+ originalInput.getRowType().getFieldCount(),
+ rewrittenInput.getRowType().getFieldCount());
+ final RelCollation remappedCollation = RexUtil.apply(oldToNewFields, rel.getCollation());
+
+ final Sort rewrittenSort =
+ LogicalSort.create(rewrittenInput, remappedCollation, rel.offset, rel.fetch);
+
+ // The field order is untouched by a Sort: condition and mapping of the input are passed on.
+ return new Frame(rel, rewrittenSort, inputFrame.c, inputFrame.oldToNewOutputs);
+ }
+
+ /**
+ * Rewrites a {@link Values}.
+ *
+ * @param rel Values to be rewritten
+ * @return always {@code null}; no frame is produced for a Values
+ */
+ public Frame decorrelateRel(Values rel) {
+ // There are no inputs, so rel does not need to be changed.
+ return null;
+ }
+
+ public Frame decorrelateRel(LogicalCorrelate rel) {
+ // correlation conditions in the inputs of a Correlate are not allowed for now, so use the generic decorrelateRel(RelNode) fallback (the cast selects that overload)
+ return decorrelateRel((RelNode) rel);
+ }
+
+ /**
+ * Fallback if none of the other {@code decorrelateRel} methods match.
+ *
+ * <p>Decorrelates every input and puts the rewritten inputs under a copy of {@code rel}.
+ * The rewrite is abandoned as soon as one input was not rewritten or still carries a
+ * correlation condition.
+ *
+ * @param rel relational expression to rewrite
+ * @return a frame without correlation condition and with an identity output mapping,
+ * or {@code null} if the rewrite was abandoned
+ */
+ public Frame decorrelateRel(RelNode rel) {
+ RelNode newRel = rel.copy(rel.getTraitSet(), rel.getInputs());
+ if (rel.getInputs().size() > 0) {
+ List<RelNode> oldInputs = rel.getInputs();
+ List<RelNode> newInputs = new ArrayList<>();
+ for (int i = 0; i < oldInputs.size(); ++i) {
+ final Frame frame = getInvoke(oldInputs.get(i));
+ if (frame == null || frame.c != null) {
+ // if input is not rewritten, or if it produces correlated variables, terminate rewrite
+ return null;
+ }
+ newInputs.add(frame.r);
+ newRel.replaceInput(i, frame.r); // patch the copy in place with the rewritten input
+ }
+
+ if (!Util.equalShallow(oldInputs, newInputs)) { // inputs differ from the old ones: build a fresh copy over the new inputs
+ newRel = rel.copy(rel.getTraitSet(), newInputs);
+ }
+ }
+ // the output position should not change since there are no corVars coming from below.
+ return new Frame(rel, newRel, null, identityMap(rel.getRowType().getFieldCount()));
+ }
+
+ /* Returns an immutable map with the identity [0: 0, .., count-1: count-1]. */
+ private static Map<Integer, Integer> identityMap(int count) {
+ com.google.common.collect.ImmutableMap.Builder<Integer, Integer> builder =
+ com.google.common.collect.ImmutableMap.builder();
+ for (int i = 0; i < count; i++) {
+ builder.put(i, i);
+ }
+ return builder.build();
+ }
+
+ /**
+ * Looks up output field {@code i} of {@code rel} and returns it if it is a literal.
+ *
+ * @param rel relational expression to inspect
+ * @param i index into the project list
+ * @return the literal at position {@code i} when {@code rel} is a {@link Project} and that
+ * expression is a {@link RexLiteral}; {@code null} in every other case
+ */
+ private static RexLiteral projectedLiteral(RelNode rel, int i) {
+ if (!(rel instanceof Project)) {
+ return null;
+ }
+ final RexNode candidate = ((Project) rel).getProjects().get(i);
+ return candidate instanceof RexLiteral ? (RexLiteral) candidate : null;
+ }
+ }
+
+ /** Builds a {@link CorelMap}. */
+ private static class CorelMapBuilder extends RelShuttleImpl {
+ private final int maxCnfNodeCount;
+ // nested correlation variables in SubQuery, such as:
+ // SELECT * FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t1.a = t2.c AND
+ // t2.d IN (SELECT t3.d FROM t3 WHERE t1.b = t3.e)
+ boolean hasNestedCorScope = false;
+ // has unsupported correlation condition, such as:
+ // SELECT * FROM l WHERE a IN (SELECT c FROM r WHERE l.b IN (SELECT e FROM t))
+ // SELECT a FROM l WHERE b IN (SELECT r1.e FROM r1 WHERE l.a = r1.d UNION SELECT r2.i FROM r2)
+ // SELECT * FROM l WHERE EXISTS (SELECT * FROM r LEFT JOIN (SELECT * FROM t WHERE t.j = l.b) t1 ON r.f = t1.k)
+ // SELECT * FROM l WHERE b IN (SELECT MIN(e) FROM r WHERE l.c > r.f)
+ // SELECT * FROM l WHERE b IN (SELECT MIN(e) OVER() FROM r WHERE l.c > r.f)
+ boolean hasUnsupportedCorCondition = false;
+ // true if SubQuery rel tree has Aggregate node, else false.
+ boolean hasAggregateNode = false;
+ // true if SubQuery rel tree has Over node, else false.
+ boolean hasOverNode = false;
+
+ public CorelMapBuilder(int maxCnfNodeCount) {
+ this.maxCnfNodeCount = maxCnfNodeCount; // later handed to analyzeCorConditions when filter predicates are checked
+ }
+
+ final SortedMap<CorrelationId, RelNode> mapCorToCorRel = new TreeMap<>();
+ final com.google.common.collect.SortedSetMultimap<RelNode, CorRef> mapRefRelToCorRef =
+ com.google.common.collect.Multimaps.newSortedSetMultimap(
+ new HashMap<RelNode, Collection<CorRef>>(),
+ new com.google.common.base.Supplier<TreeSet<CorRef>>() {
+ public TreeSet<CorRef> get() {
+ Bug.upgrade("use MultimapBuilder when we're on Guava-16");
+ return com.google.common.collect.Sets.newTreeSet();
+ }
+ });
+ final Map<RexFieldAccess, CorRef> mapFieldAccessToCorVar = new HashMap<>();
+ final Map<RelNode, Set<CorrelationId>> mapSubQueryNodeToCorSet = new HashMap<>();
+
+ int corrIdGenerator = 0;
+ final Deque<RelNode> corNodeStack = new ArrayDeque<>();
+
+ /**
+ * Creates a CorelMap by iterating over a {@link RelNode} tree.
+ *
+ * <p>Each given root is passed through {@code stripHep} and visited by this builder; the maps
+ * filled during the visits are then handed to {@link CorelMap#of}.
+ *
+ * @param rels roots of the trees to analyze
+ * @return the correlation map collected from all given trees
+ */
+ CorelMap build(RelNode... rels) {
+ for (RelNode rel : rels) {
+ stripHep(rel).accept(this);
+ }
+ return CorelMap.of(mapRefRelToCorRef, mapCorToCorRel, mapSubQueryNodeToCorSet);
+ }
+
+ @Override
+ protected RelNode visitChild(RelNode parent, int i, RelNode input) {
+ return super.visitChild(parent, i, stripHep(input)); // every child goes through stripHep before it is visited
+ }
+
+ @Override
+ public RelNode visit(LogicalCorrelate correlate) {
+ // TODO correlation conditions are not allowed in the inputs of a Correlate for now
+ // If correlation conditions in the correlate inputs reference correlate output variables,
+ // that is not supported, e.g.
+ // SELECT * FROM outer_table l WHERE l.c IN (
+ // SELECT f1 FROM (
+ // SELECT * FROM inner_table r WHERE r.d IN (SELECT x.i FROM x WHERE x.j = l.b)) t,
+ // LATERAL TABLE(table_func(t.f)) AS T(f1)
+ // ))
+ // other cases should be supported, e.g.
+ // SELECT * FROM outer_table l WHERE l.c IN (
+ // SELECT f1 FROM (
+ // SELECT * FROM inner_table r WHERE r.d IN (SELECT x.i FROM x WHERE x.j = r.e)) t,
+ // LATERAL TABLE(table_func(t.f)) AS T(f1)
+ // ))
+ checkCorConditionOfInput(correlate.getLeft()); // flags the query if a Filter, Project or Join below the left input uses a correlation variable
+ checkCorConditionOfInput(correlate.getRight()); // same check for the right input
+
+ visitChild(correlate, 0, correlate.getLeft());
+ visitChild(correlate, 1, correlate.getRight());
+ return correlate;
+ }
+
+ @Override
+ public RelNode visit(LogicalJoin join) {
+ switch (join.getJoinType()) { // for outer joins, check the inputs selected below for correlation variables
+ case LEFT:
+ checkCorConditionOfInput(join.getRight());
+ break;
+ case RIGHT:
+ checkCorConditionOfInput(join.getLeft());
+ break;
+ case FULL:
+ checkCorConditionOfInput(join.getLeft());
+ checkCorConditionOfInput(join.getRight());
+ break;
+ default:
+ break;
+ }
+
+ final boolean hasSubQuery = RexUtil.SubQueryFinder.find(join.getCondition()) != null;
+ try {
+ if (!corNodeStack.isEmpty()) {
+ mapSubQueryNodeToCorSet.put(join, corNodeStack.peek().getVariablesSet()); // record the variables set of the node currently on top of the stack for this join
+ }
+ if (hasSubQuery) {
+ corNodeStack.push(join); // the join stays on the stack while its condition is visited
+ }
+ checkCorCondition(join);
+ join.getCondition().accept(rexVisitor(join));
+ } finally {
+ if (hasSubQuery) {
+ corNodeStack.pop(); // undo the push above even if visiting the condition failed
+ }
+ }
+ visitChild(join, 0, join.getLeft());
+ visitChild(join, 1, join.getRight());
+ return join;
+ }
+
+ @Override
+ public RelNode visit(LogicalFilter filter) {
+ final boolean hasSubQuery = RexUtil.SubQueryFinder.find(filter.getCondition()) != null;
+ try {
+ if (!corNodeStack.isEmpty()) {
+ mapSubQueryNodeToCorSet.put(filter, corNodeStack.peek().getVariablesSet()); // record the variables set of the node currently on top of the stack for this filter
+ }
+ if (hasSubQuery) {
+ corNodeStack.push(filter); // the filter stays on the stack while its condition is visited
+ }
+ checkCorCondition(filter);
+ filter.getCondition().accept(rexVisitor(filter));
+ for (CorrelationId correlationId : filter.getVariablesSet()) {
+ mapCorToCorRel.put(correlationId, filter); // this filter is registered as the rel for each of its correlation ids
+ }
+ } finally {
+ if (hasSubQuery) {
+ corNodeStack.pop(); // undo the push above even if visiting the condition failed
+ }
+ }
+ return super.visit(filter);
+ }
+
+ @Override
+ public RelNode visit(LogicalProject project) {
+ hasOverNode = RexOver.containsOver(project.getProjects(), null); // NOTE(review): this overwrites the flag for every Project, whereas hasAggregateNode is only ever set to true between resets; confirm an OVER in an enclosing Project need not be remembered
+ final boolean hasSubQuery = RexUtil.SubQueryFinder.find(project.getProjects()) != null;
+ try {
+ if (!corNodeStack.isEmpty()) {
+ mapSubQueryNodeToCorSet.put(project, corNodeStack.peek().getVariablesSet()); // record the variables set of the node currently on top of the stack for this project
+ }
+ if (hasSubQuery) {
+ corNodeStack.push(project); // the project stays on the stack while its expressions are visited
+ }
+ checkCorCondition(project);
+ for (RexNode node : project.getProjects()) {
+ node.accept(rexVisitor(project));
+ }
+ } finally {
+ if (hasSubQuery) {
+ corNodeStack.pop(); // undo the push above even if visiting the expressions failed
+ }
+ }
+ return super.visit(project);
+ }
+
+ @Override
+ public RelNode visit(LogicalAggregate aggregate) {
+ hasAggregateNode = true; // stays set until rexVisitor resets it when entering the next sub-query
+ return super.visit(aggregate);
+ }
+
+ @Override
+ public RelNode visit(LogicalUnion union) {
+ checkCorConditionOfSetOpInputs(union); // flag the query if any union input uses a correlation variable
+ return super.visit(union);
+ }
+
+ @Override
+ public RelNode visit(LogicalMinus minus) {
+ checkCorConditionOfSetOpInputs(minus); // flag the query if any minus input uses a correlation variable
+ return super.visit(minus);
+ }
+
+ @Override
+ public RelNode visit(LogicalIntersect intersect) {
+ checkCorConditionOfSetOpInputs(intersect); // flag the query if any intersect input uses a correlation variable
+ return super.visit(intersect);
+ }
+
+ /**
+ * Checks whether the predicate of a filter has an unsupported correlation condition,
+ * e.g. SELECT * FROM l WHERE a IN (SELECT c FROM r WHERE l.b = r.d OR r.d > 10)
+ *
+ * <p>The check only runs for filters registered in {@code mapSubQueryNodeToCorSet} and only
+ * while no unsupported condition has been found yet. The predicate is handed to
+ * {@code analyzeCorConditions} (presumably splitting it into correlated, non-correlated and
+ * unsupported parts; confirm against that helper). Any unsupported part sets
+ * {@code hasUnsupportedCorCondition}; otherwise the flag is set when a correlated condition is
+ * a call with an operator other than EQUALS and an Aggregate or Over node has been seen.
+ *
+ * @param filter filter whose condition is inspected
+ */
+ private void checkCorCondition(final LogicalFilter filter) {
+ if (mapSubQueryNodeToCorSet.containsKey(filter) && !hasUnsupportedCorCondition) {
+ final List<RexNode> corConditions = new ArrayList<>();
+ final List<RexNode> unsupportedCorConditions = new ArrayList<>();
+ analyzeCorConditions(
+ mapSubQueryNodeToCorSet.get(filter),
+ filter.getCondition(),
+ filter.getCluster().getRexBuilder(),
+ maxCnfNodeCount,
+ corConditions,
+ new ArrayList<>(),
+ unsupportedCorConditions);
+ if (!unsupportedCorConditions.isEmpty()) {
+ hasUnsupportedCorCondition = true;
+ } else if (!corConditions.isEmpty()) {
+ boolean hasNonEquals = false;
+ for (RexNode node : corConditions) {
+ if (node instanceof RexCall && ((RexCall) node).getOperator() != SqlStdOperatorTable.EQUALS) {
+ hasNonEquals = true;
+ break;
+ }
+ }
+ // agg or over with non-equality correlation condition is unsupported, e.g.
+ // SELECT * FROM l WHERE b IN (SELECT MIN(e) FROM r WHERE l.c > r.f)
+ // SELECT * FROM l WHERE b IN (SELECT MIN(e) OVER() FROM r WHERE l.c > r.f)
+ hasUnsupportedCorCondition = hasNonEquals && (hasAggregateNode || hasOverNode);
+ }
+ }
+ }
+
+ /**
+ * check whether the predicate on join has unsupported correlation condition.
+ * e.g. SELECT * FROM l WHERE a IN (SELECT c FROM r WHERE l.b IN (SELECT e FROM t))
+ */
+ private void checkCorCondition(final LogicalJoin join) {
+ if (!hasUnsupportedCorCondition) {
+ join.getCondition().accept(new RexVisitorImpl<Void>(true) {
+ @Override
+ public Void visitCorrelVariable(RexCorrelVariable correlVariable) {
+ hasUnsupportedCorCondition = true;
+ return super.visitCorrelVariable(correlVariable);
+ }
+ });
+ }
+ }
+
+ /**
+ * check whether the project has correlation expressions.
+ * e.g. SELECT * FROM l WHERE a IN (SELECT l.b FROM r)
+ */
+ private void checkCorCondition(final LogicalProject project) {
+ if (!hasUnsupportedCorCondition) {
+ for (RexNode node : project.getProjects()) {
+ node.accept(new RexVisitorImpl<Void>(true) {
+ @Override
+ public Void visitCorrelVariable(RexCorrelVariable correlVariable) {
+ hasUnsupportedCorCondition = true;
+ return super.visitCorrelVariable(correlVariable);
+ }
+ });
+ }
+ }
+ }
+
+ /**
+ * Checks whether the tree below a node contains a correlation condition.
+ * e.g. SELECT * FROM l WHERE EXISTS (SELECT * FROM r LEFT JOIN (SELECT * FROM t WHERE t.j=l.b) t1 ON r.f=t1.k)
+ * The above SQL cannot be converted to a semi-join plan,
+ * because the right input of the Left-Join has the correlation condition (t.j=l.b).
+ *
+ * <p>The input is walked with a shuttle that inspects filter conditions, project expressions
+ * and join conditions; {@code hasUnsupportedCorCondition} is set as soon as one of them
+ * contains a {@link RexCorrelVariable}.
+ *
+ * @param input root of the tree to inspect
+ */
+ private void checkCorConditionOfInput(final RelNode input) {
+ final RelShuttleImpl shuttle = new RelShuttleImpl() {
+ final RexVisitor<Void> visitor = new RexVisitorImpl<Void>(true) {
+ @Override
+ public Void visitCorrelVariable(RexCorrelVariable correlVariable) {
+ hasUnsupportedCorCondition = true;
+ return super.visitCorrelVariable(correlVariable);
+ }
+ };
+
+ @Override
+ public RelNode visit(LogicalFilter filter) {
+ filter.getCondition().accept(visitor);
+ return super.visit(filter);
+ }
+
+ @Override
+ public RelNode visit(LogicalProject project) {
+ for (RexNode rex : project.getProjects()) {
+ rex.accept(visitor);
+ }
+ return super.visit(project);
+ }
+
+ @Override
+ public RelNode visit(LogicalJoin join) {
+ join.getCondition().accept(visitor);
+ return super.visit(join);
+ }
+ };
+ input.accept(shuttle);
+ }
+
+ /**
+ * Runs {@code checkCorConditionOfInput} on every input of a set operation,
+ * e.g. SELECT a FROM l WHERE b IN (SELECT r1.e FROM r1 WHERE l.a = r1.d UNION SELECT r2.i FROM r2)
+ *
+ * @param setOp set operation whose inputs are inspected
+ */
+ private void checkCorConditionOfSetOpInputs(SetOp setOp) {
+ setOp.getInputs().forEach(this::checkCorConditionOfInput);
+ }
+
+ private RexVisitorImpl<Void> rexVisitor(final RelNode rel) {
+ return new RexVisitorImpl<Void>(true) {
+ @Override
+ public Void visitSubQuery(RexSubQuery subQuery) {
+ hasAggregateNode = false; // reset to default value
+ hasOverNode = false; // reset to default value
+ subQuery.rel.accept(CorelMapBuilder.this); // analyze the sub-query tree with the enclosing builder
+ return super.visitSubQuery(subQuery);
+ }
+
+ @Override
+ public Void visitFieldAccess(RexFieldAccess fieldAccess) {
+ final RexNode ref = fieldAccess.getReferenceExpr();
+ if (ref instanceof RexCorrelVariable) {
+ final RexCorrelVariable var = (RexCorrelVariable) ref;
+ // check the scope of the correlation id
+ // we do not support nested correlation variables in SubQuery, such as:
+ // select * from t1 where exists (select * from t2 where t1.a = t2.c and
+ // t2.d in (select t3.d from t3 where t1.b = t3.e))
+ if (!hasUnsupportedCorCondition) {
+ hasUnsupportedCorCondition = !mapSubQueryNodeToCorSet.containsKey(rel); // a correlated access in a rel without a recorded variables set is unsupported
+ }
+ if (!hasNestedCorScope && mapSubQueryNodeToCorSet.containsKey(rel)) {
+ hasNestedCorScope = !mapSubQueryNodeToCorSet.get(rel).contains(var.id); // the variable is not in the set recorded for this rel
+ }
+
+ if (mapFieldAccessToCorVar.containsKey(fieldAccess)) {
+ // for cases where different Rel nodes are referring to
+ // same correlation var (e.g. in case of NOT IN)
+ // avoid generating another correlation var
+ // and record that 'rel' is using the same correlation
+ mapRefRelToCorRef.put(rel, mapFieldAccessToCorVar.get(fieldAccess));
+ } else {
+ final CorRef correlation = new CorRef(
+ var.id,
+ fieldAccess.getField().getIndex(),
+ corrIdGenerator++);
+ mapFieldAccessToCorVar.put(fieldAccess, correlation);
+ mapRefRelToCorRef.put(rel, correlation);
+ }
+ }
+ return super.visitFieldAccess(fieldAccess);
+ }
+ };
+ }
+ }
+
+ /**
+ * A unique reference to a correlation field.
+ *
+ * <p>For instance, if a RelNode references emp.name multiple times, it would
+ * result in multiple {@code CorRef} objects that differ just in
+ * {@link CorRef#uniqueKey}.
+ *
+ * <p>{@link #equals}, {@link #hashCode} and {@link #compareTo} all look at the same three
+ * components and compare the correlation id by value, so that they agree with each other.
+ */
+ private static class CorRef implements Comparable<CorRef> {
+ final int uniqueKey;
+ final CorrelationId corr;
+ final int field;
+
+ CorRef(CorrelationId corr, int field, int uniqueKey) {
+ this.corr = corr;
+ this.field = field;
+ this.uniqueKey = uniqueKey;
+ }
+
+ @Override
+ public String toString() {
+ return corr.getName() + '.' + field;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(uniqueKey, corr, field);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (!(o instanceof CorRef)) {
+ return false;
+ }
+ final CorRef that = (CorRef) o;
+ // The correlation id is compared by value, not by identity: hashCode() and compareTo()
+ // already treat equal ids as the same, and equals() must not disagree with them.
+ return uniqueKey == that.uniqueKey
+ && field == that.field
+ && Objects.equals(corr, that.corr);
+ }
+
+ /** Orders by correlation id, then by field index, then by unique key. */
+ @Override
+ public int compareTo(@Nonnull CorRef o) {
+ int c = corr.compareTo(o.corr);
+ if (c != 0) {
+ return c;
+ }
+ c = Integer.compare(field, o.field);
+ if (c != 0) {
+ return c;
+ }
+ return Integer.compare(uniqueKey, o.uniqueKey);
+ }
+ }
+
+ /**
+ * Records where correlation variables occur in a tree of {@link RelNode}s.
+ *
+ * <p>The decorrelation process is driven by this structure.
+ * Treat it as immutable; rebuild it if the tree is modified.
+ *
+ * <p>It holds three maps:<ol>
+ *
+ * <li>{@link #mapRefRelToCorRef} maps a {@link RelNode} to the correlated variables it references;
+ *
+ * <li>{@link #mapCorToCorRel} maps a correlated variable to the {@link RelNode} providing it;
+ *
+ * <li>{@link #mapSubQueryNodeToCorSet} maps a {@link RelNode} to the correlated variables it has;
+ *
+ * </ol>
+ */
+ private static class CorelMap {
+ private final com.google.common.collect.Multimap<RelNode, CorRef> mapRefRelToCorRef;
+ private final SortedMap<CorrelationId, RelNode> mapCorToCorRel;
+ private final Map<RelNode, Set<CorrelationId>> mapSubQueryNodeToCorSet;
+
+ // TODO: create immutable copies of all maps
+ private CorelMap(
+ com.google.common.collect.Multimap<RelNode, CorRef> mapRefRelToCorRef,
+ SortedMap<CorrelationId, RelNode> mapCorToCorRel,
+ Map<RelNode, Set<CorrelationId>> mapSubQueryNodeToCorSet) {
+ this.mapRefRelToCorRef = mapRefRelToCorRef;
+ this.mapCorToCorRel = mapCorToCorRel;
+ // only this map is defensively copied
+ this.mapSubQueryNodeToCorSet = com.google.common.collect.ImmutableMap.copyOf(mapSubQueryNodeToCorSet);
+ }
+
+ /** Creates a CorelMap with given contents. */
+ public static CorelMap of(
+ com.google.common.collect.SortedSetMultimap<RelNode, CorRef> mapRefRelToCorVar,
+ SortedMap<CorrelationId, RelNode> mapCorToCorRel,
+ Map<RelNode, Set<CorrelationId>> mapSubQueryNodeToCorSet) {
+ return new CorelMap(mapRefRelToCorVar, mapCorToCorRel, mapSubQueryNodeToCorSet);
+ }
+
+ /**
+ * Returns whether there are any correlating variables in this statement.
+ *
+ * @return whether there are any correlating variables
+ */
+ boolean hasCorrelation() {
+ return !mapCorToCorRel.isEmpty();
+ }
+
+ @Override
+ public String toString() {
+ final StringBuilder text = new StringBuilder();
+ text.append("mapRefRelToCorRef=").append(mapRefRelToCorRef);
+ text.append("\nmapCorToCorRel=").append(mapCorToCorRel);
+ text.append("\nmapSubQueryNodeToCorSet=").append(mapSubQueryNodeToCorSet);
+ text.append("\n");
+ return text.toString();
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == this) {
+ return true;
+ }
+ if (!(obj instanceof CorelMap)) {
+ return false;
+ }
+ final CorelMap that = (CorelMap) obj;
+ return mapRefRelToCorRef.equals(that.mapRefRelToCorRef)
+ && mapCorToCorRel.equals(that.mapCorToCorRel)
+ && mapSubQueryNodeToCorSet.equals(that.mapSubQueryNodeToCorSet);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(mapRefRelToCorRef, mapCorToCorRel, mapSubQueryNodeToCorSet);
+ }
+ }
+
+ /**
+ * Outcome of decorrelating one relational expression: the rewritten expression, the
+ * correlation condition it carries (if any) and where the old output fields ended up.
+ */
+ private static class Frame {
+ // the new rel
+ final RelNode r;
+ // the condition that contains correlation variables; null if there is none
+ final RexNode c;
+ // maps the field indices of the old rel to the field indices of the new rel
+ final com.google.common.collect.ImmutableSortedMap<Integer, Integer> oldToNewOutputs;
+
+ Frame(RelNode oldRel, RelNode newRel, RexNode corCondition, Map<Integer, Integer> oldToNewOutputs) {
+ this.r = Preconditions.checkNotNull(newRel);
+ this.c = corCondition;
+ this.oldToNewOutputs = com.google.common.collect.ImmutableSortedMap.copyOf(oldToNewOutputs);
+ // every key must be a field of the old rel, every value a field of the new rel
+ assert allLessThan(this.oldToNewOutputs.keySet(), oldRel.getRowType().getFieldCount(), Litmus.THROW);
+ assert allLessThan(this.oldToNewOutputs.values(), r.getRowType().getFieldCount(), Litmus.THROW);
+ }
+
+ /** Returns the input field indices used by the correlation condition; empty if there is no condition. */
+ List<Integer> getCorInputRefIndices() {
+ if (c == null) {
+ return new ArrayList<>();
+ }
+ return RelOptUtil.InputFinder.bits(c).toList();
+ }
+
+ /** Checks that every given integer is below {@code limit}, reporting the outcome through {@code ret}. */
+ private static boolean allLessThan(Collection<Integer> integers, int limit, Litmus ret) {
+ for (int candidate : integers) {
+ if (candidate >= limit) {
+ return ret.fail("out of range; value: {}, limit: {}", candidate, limit);
+ }
+ }
+ return ret.succeed();
+ }
+ }
+
+ /**
+ * Result describing the relational expression after decorrelation
+ * and where to find the equivalent non-correlated expressions and correlated conditions.
+ *
+ * <p>Internally this is an immutable copy of a map from each sub-query to a pair of a
+ * {@link RelNode} and a {@link RexNode}.
+ */
+ public static class Result {
+ private final com.google.common.collect.ImmutableMap<RexSubQuery, Pair<RelNode, RexNode>> subQueryMap;
+ static final Result EMPTY = new Result(new HashMap<>()); // shared instance backed by an empty map
+
+ private Result(Map<RexSubQuery, Pair<RelNode, RexNode>> subQueryMap) {
+ this.subQueryMap = com.google.common.collect.ImmutableMap.copyOf(subQueryMap);
+ }
+
+ public Pair<RelNode, RexNode> getSubQueryEquivalent(RexSubQuery subQuery) {
+ return subQueryMap.get(subQuery); // null when the sub-query has no entry
+ }
+ }
+}
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/util/JoinTypeUtil.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/util/JoinTypeUtil.java
index 53c8d41..33ad1fe 100644
--- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/util/JoinTypeUtil.java
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/util/JoinTypeUtil.java
@@ -20,9 +20,7 @@ package org.apache.flink.table.plan.util;
import org.apache.flink.table.runtime.join.FlinkJoinType;
-import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
-import org.apache.calcite.rel.core.SemiJoin;
/**
* Utility for {@link FlinkJoinType}.
@@ -32,7 +30,7 @@ public class JoinTypeUtil {
/**
* Converts {@link JoinRelType} to {@link FlinkJoinType}.
*/
- public static FlinkJoinType toFlinkJoinType(JoinRelType joinRelType) {
+ public static FlinkJoinType getFlinkJoinType(JoinRelType joinRelType) {
switch (joinRelType) {
case INNER:
return FlinkJoinType.INNER;
@@ -42,39 +40,13 @@ public class JoinTypeUtil {
return FlinkJoinType.RIGHT;
case FULL:
return FlinkJoinType.FULL;
+ case SEMI:
+ return FlinkJoinType.SEMI;
+ case ANTI:
+ return FlinkJoinType.ANTI;
default:
throw new IllegalArgumentException("invalid: " + joinRelType);
}
}
- /**
- * Gets {@link FlinkJoinType} of the input Join RelNode.
- */
- public static FlinkJoinType getFlinkJoinType(Join join) {
- if (join instanceof SemiJoin) {
- // TODO supports ANTI
- return FlinkJoinType.SEMI;
- } else {
- return toFlinkJoinType(join.getJoinType());
- }
- }
-
- /**
- * Converts {@link FlinkJoinType} to {@link JoinRelType}.
- */
- public static JoinRelType toJoinRelType(FlinkJoinType joinType) {
- switch (joinType) {
- case INNER:
- return JoinRelType.INNER;
- case LEFT:
- return JoinRelType.LEFT;
- case RIGHT:
- return JoinRelType.RIGHT;
- case FULL:
- return JoinRelType.FULL;
- default:
- throw new IllegalArgumentException("invalid: " + joinType);
- }
- }
-
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala
index 741c81d..7675b84 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala
@@ -86,8 +86,6 @@ class RelTimeIndicatorConverter(rexBuilder: RexBuilder) extends RelShuttle {
aggregate.getNamedProperties,
convAggregate)
- // TODO supports SemiJoin
-
case watermarkAssigner: LogicalWatermarkAssigner =>
watermarkAssigner
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdColumnUniqueness.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdColumnUniqueness.scala
index c3c839c..25754a9 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdColumnUniqueness.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdColumnUniqueness.scala
@@ -395,25 +395,22 @@ class FlinkRelMdColumnUniqueness private extends MetadataHandler[BuiltInMetadata
mq: RelMetadataQuery,
columns: ImmutableBitSet,
ignoreNulls: Boolean): JBoolean = {
- areColumnsUniqueOfJoin(
- rel.analyzeCondition(),
- rel.getJoinType,
- rel.getLeft.getRowType,
- (leftSet: ImmutableBitSet) => mq.areColumnsUnique(rel.getLeft, leftSet, ignoreNulls),
- (rightSet: ImmutableBitSet) => mq.areColumnsUnique(rel.getRight, rightSet, ignoreNulls),
- mq,
- columns
- )
- }
-
- def areColumnsUnique(
- rel: SemiJoin,
- mq: RelMetadataQuery,
- columns: ImmutableBitSet,
- ignoreNulls: Boolean): JBoolean = {
- // only return the unique keys from the LHS since a semijoin only
- // returns the LHS
- mq.areColumnsUnique(rel.getLeft, columns, ignoreNulls)
+ rel.getJoinType match {
+ case JoinRelType.SEMI | JoinRelType.ANTI =>
+ // only return the unique keys from the LHS since a SEMI/ANTI join only
+ // returns the LHS
+ mq.areColumnsUnique(rel.getLeft, columns, ignoreNulls)
+ case _ =>
+ areColumnsUniqueOfJoin(
+ rel.analyzeCondition(),
+ rel.getJoinType,
+ rel.getLeft.getRowType,
+ (leftSet: ImmutableBitSet) => mq.areColumnsUnique(rel.getLeft, leftSet, ignoreNulls),
+ (rightSet: ImmutableBitSet) => mq.areColumnsUnique(rel.getRight, rightSet, ignoreNulls),
+ mq,
+ columns
+ )
+ }
}
def areColumnsUnique(
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdDistinctRowCount.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdDistinctRowCount.scala
index bdb0ff8..049bdaa 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdDistinctRowCount.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdDistinctRowCount.scala
@@ -467,27 +467,19 @@ class FlinkRelMdDistinctRowCount private extends MetadataHandler[BuiltInMetadata
return 1D
}
}
- RelMdUtil.getJoinDistinctRowCount(mq, rel, rel.getJoinType, groupKey, predicate, false)
- }
-
- def getDistinctRowCount(
- rel: SemiJoin,
- mq: RelMetadataQuery,
- groupKey: ImmutableBitSet,
- predicate: RexNode): JDouble = {
- if (predicate == null || predicate.isAlwaysTrue) {
- if (groupKey.isEmpty) {
- return 1D
- }
- }
- // create a RexNode representing the selectivity of the
- // semijoin filter and pass it to getDistinctRowCount
- var newPred = FlinkRelMdUtil.makeSemiJoinSelectivityRexNode(mq, rel)
- if (predicate != null) {
- val rexBuilder = rel.getCluster.getRexBuilder
- newPred = rexBuilder.makeCall(SqlStdOperatorTable.AND, newPred, predicate)
+ rel.getJoinType match {
+ case JoinRelType.SEMI | JoinRelType.ANTI =>
+ // create a RexNode representing the selectivity of the
+ // semi-join filter and pass it to getDistinctRowCount
+ var newPred = FlinkRelMdUtil.makeSemiAntiJoinSelectivityRexNode(mq, rel)
+ if (predicate != null) {
+ val rexBuilder = rel.getCluster.getRexBuilder
+ newPred = rexBuilder.makeCall(SqlStdOperatorTable.AND, newPred, predicate)
+ }
+ mq.getDistinctRowCount(rel.getLeft, groupKey, newPred)
+ case _ =>
+ RelMdUtil.getJoinDistinctRowCount(mq, rel, rel.getJoinType, groupKey, predicate, false)
}
- mq.getDistinctRowCount(rel.getLeft, groupKey, newPred)
}
def getDistinctRowCount(
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdModifiedMonotonicity.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdModifiedMonotonicity.scala
index 26a6e0e..d82f4f9 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdModifiedMonotonicity.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdModifiedMonotonicity.scala
@@ -29,8 +29,6 @@ import org.apache.flink.table.plan.nodes.physical.batch.{BatchExecCorrelate, Bat
import org.apache.flink.table.plan.nodes.physical.stream._
import org.apache.flink.table.plan.schema.DataStreamTable
import org.apache.flink.table.plan.stats.{WithLower, WithUpper}
-import org.apache.flink.table.plan.util.JoinTypeUtil
-import org.apache.flink.table.runtime.join.FlinkJoinType
import org.apache.flink.table.{JByte, JDouble, JFloat, JList, JLong, JShort}
import org.apache.calcite.plan.hep.HepRelVertex
@@ -376,7 +374,7 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon
val joinInfo = rel.analyzeCondition
val leftKeys = joinInfo.leftKeys
val rightKeys = joinInfo.rightKeys
- val joinType = JoinTypeUtil.toFlinkJoinType(rel.getJoinType)
+ val joinType = rel.getJoinType
val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
// if group set contains update return null
@@ -391,11 +389,11 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon
val isKeyAllAppend = isAllConstantOnKeys(left, leftKeys.toIntArray) &&
isAllConstantOnKeys(right, rightKeys.toIntArray)
- if (!containDelete && !joinType.equals(FlinkJoinType.ANTI) && isKeyAllAppend &&
+ if (!containDelete && !joinType.equals(JoinRelType.ANTI) && isKeyAllAppend &&
(containUpdate && joinInfo.isEqui || !containUpdate)) {
// output rowtype of semi equals to the rowtype of left child
- if (joinType.equals(FlinkJoinType.SEMI)) {
+ if (joinType.equals(JoinRelType.SEMI)) {
fmq.getRelModifiedMonotonicity(left)
} else {
val leftFieldMonotonicities = fmq.getRelModifiedMonotonicity(left).fieldMonotonicities
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdPercentageOriginalRows.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdPercentageOriginalRows.scala
index 91486ab..21ba508 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdPercentageOriginalRows.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdPercentageOriginalRows.scala
@@ -24,7 +24,7 @@ import org.apache.flink.table.plan.nodes.physical.batch.BatchExecGroupAggregateB
import org.apache.calcite.plan.volcano.RelSubset
import org.apache.calcite.rel.RelNode
-import org.apache.calcite.rel.core.{Aggregate, Join, SemiJoin, Union}
+import org.apache.calcite.rel.core.{Aggregate, Join, JoinRelType, Union}
import org.apache.calcite.rel.metadata._
import org.apache.calcite.util.{BuiltInMethod, Util}
@@ -57,18 +57,19 @@ class FlinkRelMdPercentageOriginalRows private
def getPercentageOriginalRows(rel: Join, mq: RelMetadataQuery): JDouble = {
val left: JDouble = mq.getPercentageOriginalRows(rel.getLeft)
- val right: JDouble = mq.getPercentageOriginalRows(rel.getRight)
- if (left == null || right == null) {
- null
- } else {
- left * right
+ rel.getJoinType match {
+ case JoinRelType.SEMI | JoinRelType.ANTI =>
+ left
+ case _ =>
+ val right: JDouble = mq.getPercentageOriginalRows(rel.getRight)
+ if (left == null || right == null) {
+ null
+ } else {
+ left * right
+ }
}
}
- def getPercentageOriginalRows(rel: SemiJoin, mq: RelMetadataQuery): JDouble = {
- mq.getPercentageOriginalRows(rel.getLeft)
- }
-
def getPercentageOriginalRows(rel: Union, mq: RelMetadataQuery): JDouble = {
var numerator: JDouble = 0.0
var denominator: JDouble = 0.0
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdPopulationSize.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdPopulationSize.scala
index 2c03049..d927a6d 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdPopulationSize.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdPopulationSize.scala
@@ -302,12 +302,14 @@ class FlinkRelMdPopulationSize private extends MetadataHandler[BuiltInMetadata.P
def getPopulationSize(
rel: Join,
mq: RelMetadataQuery,
- groupKey: ImmutableBitSet): JDouble = RelMdUtil.getJoinPopulationSize(mq, rel, groupKey)
-
- def getPopulationSize(
- rel: SemiJoin,
- mq: RelMetadataQuery,
- groupKey: ImmutableBitSet): JDouble = mq.getPopulationSize(rel.getLeft, groupKey)
+ groupKey: ImmutableBitSet): JDouble = {
+ rel.getJoinType match {
+ case JoinRelType.SEMI | JoinRelType.ANTI =>
+ mq.getPopulationSize(rel.getLeft, groupKey)
+ case _ =>
+ RelMdUtil.getJoinPopulationSize(mq, rel, groupKey)
+ }
+ }
def getPopulationSize(
rel: Union,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdRowCount.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdRowCount.scala
index 65d7b40..77f967f 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdRowCount.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdRowCount.scala
@@ -197,6 +197,15 @@ class FlinkRelMdRowCount private extends MetadataHandler[BuiltInMetadata.RowCoun
mq.getRowCount(overWindow.getInput)
def getRowCount(join: Join, mq: RelMetadataQuery): JDouble = {
+ join.getJoinType match {
+ case JoinRelType.SEMI | JoinRelType.ANTI =>
+ val semiJoinSelectivity = FlinkRelMdUtil.makeSemiAntiJoinSelectivityRexNode(mq, join)
+ val selectivity = mq.getSelectivity(join.getLeft, semiJoinSelectivity)
+ val leftRowCount = mq.getRowCount(join.getLeft)
+ return NumberUtil.multiply(leftRowCount, selectivity)
+ case _ => // do nothing
+ }
+
val leftChild = join.getLeft
val rightChild = join.getRight
val leftRowCount = mq.getRowCount(leftChild)
@@ -313,13 +322,6 @@ class FlinkRelMdRowCount private extends MetadataHandler[BuiltInMetadata.RowCoun
join.isSemiJoinDone)
}
- def getRowCount(rel: SemiJoin, mq: RelMetadataQuery): JDouble = {
- val semiJoinSelectivity = FlinkRelMdUtil.makeSemiJoinSelectivityRexNode(mq, rel)
- val selectivity = mq.getSelectivity(rel.getLeft, semiJoinSelectivity)
- val leftRowCount = mq.getRowCount(rel.getLeft)
- NumberUtil.multiply(leftRowCount, selectivity)
- }
-
def getRowCount(rel: Union, mq: RelMetadataQuery): JDouble = {
val rowCounts = rel.getInputs.map(mq.getRowCount)
if (rowCounts.contains(null)) {
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdSelectivity.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdSelectivity.scala
index 7879e3a..2c63e06 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdSelectivity.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdSelectivity.scala
@@ -188,22 +188,19 @@ class FlinkRelMdSelectivity private extends MetadataHandler[BuiltInMetadata.Sele
if (predicate == null || predicate.isAlwaysTrue) {
1.0
} else {
- estimateSelectivity(rel, mq, predicate)
- }
- }
-
- def getSelectivity(rel: SemiJoin, mq: RelMetadataQuery, predicate: RexNode): JDouble = {
- if (predicate == null || predicate.isAlwaysTrue) {
- 1.0
- } else {
- // create a RexNode representing the selectivity of the
- // semijoin filter and pass it to getSelectivity
- val rexBuilder = rel.getCluster.getRexBuilder
- var newPred = FlinkRelMdUtil.makeSemiJoinSelectivityRexNode(mq, rel)
- if (predicate != null) {
- newPred = rexBuilder.makeCall(SqlStdOperatorTable.AND, newPred, predicate)
+ rel.getJoinType match {
+ case JoinRelType.SEMI | JoinRelType.ANTI =>
+ // create a RexNode representing the selectivity of the
+ // semi-join filter and pass it to getSelectivity
+ val rexBuilder = rel.getCluster.getRexBuilder
+ var newPred = FlinkRelMdUtil.makeSemiAntiJoinSelectivityRexNode(mq, rel)
+ if (predicate != null) {
+ newPred = rexBuilder.makeCall(SqlStdOperatorTable.AND, newPred, predicate)
+ }
+ mq.getSelectivity(rel.getLeft, newPred)
+ case _ =>
+ estimateSelectivity(rel, mq, predicate)
}
- mq.getSelectivity(rel.getLeft, newPred)
}
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdSize.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdSize.scala
index 04bfcc7..c52c947 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdSize.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdSize.scala
@@ -243,18 +243,13 @@ class FlinkRelMdSize private extends MetadataHandler[BuiltInMetadata.Size] {
getColumnSizesFromInputOrType(overWindow, mq, (0 until inputFieldCount).zipWithIndex.toMap)
}
- def averageColumnSizes(rel: Join, mq: RelMetadataQuery): JList[JDouble] =
- averageJoinColumnSizesOfJoin(rel, mq, isSemiJoin = false)
-
- def averageColumnSizes(rel: SemiJoin, mq: RelMetadataQuery): JList[JDouble] =
- averageJoinColumnSizesOfJoin(rel, mq, isSemiJoin = true)
-
- private def averageJoinColumnSizesOfJoin(
- join: Join,
- mq: RelMetadataQuery,
- isSemiJoin: Boolean): JList[JDouble] = {
- val acsOfLeft = mq.getAverageColumnSizes(join.getLeft)
- val acsOfRight = if (isSemiJoin) null else mq.getAverageColumnSizes(join.getRight)
+ def averageColumnSizes(rel: Join, mq: RelMetadataQuery): JList[JDouble] = {
+ val acsOfLeft = mq.getAverageColumnSizes(rel.getLeft)
+ val acsOfRight = if (rel.getJoinType.projectsRight) {
+ mq.getAverageColumnSizes(rel.getRight)
+ } else {
+ null
+ }
if (acsOfLeft == null && acsOfRight == null) {
null
} else if (acsOfRight == null) {
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdUniqueKeys.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdUniqueKeys.scala
index dacf30a..2c87feb 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdUniqueKeys.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdUniqueKeys.scala
@@ -339,8 +339,15 @@ class FlinkRelMdUniqueKeys private extends MetadataHandler[BuiltInMetadata.Uniqu
join: Join,
mq: RelMetadataQuery,
ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
- getJoinUniqueKeys(
- join.analyzeCondition(), join.getJoinType, join.getLeft, join.getRight, mq, ignoreNulls)
+ join.getJoinType match {
+ case JoinRelType.SEMI | JoinRelType.ANTI =>
+ // only return the unique keys from the LHS since a SEMI/ANTI join only
+ // returns the LHS
+ mq.getUniqueKeys(join.getLeft, ignoreNulls)
+ case _ =>
+ getJoinUniqueKeys(
+ join.analyzeCondition(), join.getJoinType, join.getLeft, join.getRight, mq, ignoreNulls)
+ }
}
private def getJoinUniqueKeys(
@@ -421,15 +428,6 @@ class FlinkRelMdUniqueKeys private extends MetadataHandler[BuiltInMetadata.Uniqu
retSet
}
- def getUniqueKeys(
- rel: SemiJoin,
- mq: RelMetadataQuery,
- ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
- // only return the unique keys from the LHS since a semijoin only
- // returns the LHS
- mq.getUniqueKeys(rel.getLeft, ignoreNulls)
- }
-
// TODO supports temporal table join
def getUniqueKeys(
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/common/CommonPhysicalJoin.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/common/CommonPhysicalJoin.scala
index 12f08c9..65fec34 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/common/CommonPhysicalJoin.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/common/CommonPhysicalJoin.scala
@@ -22,9 +22,11 @@ import org.apache.flink.table.plan.nodes.physical.FlinkPhysicalRel
import org.apache.flink.table.plan.util.{JoinTypeUtil, JoinUtil, RelExplainUtil}
import org.apache.flink.table.runtime.join.FlinkJoinType
-import org.apache.calcite.rel.RelWriter
+import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeField}
-import org.apache.calcite.rel.core.{Join, SemiJoin}
+import org.apache.calcite.rel.core.{CorrelationId, Join, JoinInfo, JoinRelType}
+import org.apache.calcite.rel.{RelNode, RelWriter}
+import org.apache.calcite.rex.RexNode
import org.apache.calcite.sql.validate.SqlValidatorUtil
import org.apache.calcite.util.mapping.IntPair
@@ -36,31 +38,42 @@ import scala.collection.JavaConversions._
/**
* Base physical class for flink [[Join]].
*/
-trait CommonPhysicalJoin extends Join with FlinkPhysicalRel {
+abstract class CommonPhysicalJoin(
+ cluster: RelOptCluster,
+ traitSet: RelTraitSet,
+ leftRel: RelNode,
+ rightRel: RelNode,
+ condition: RexNode,
+ joinType: JoinRelType)
+ extends Join(cluster, traitSet, leftRel, rightRel, condition, Set.empty[CorrelationId], joinType)
+ with FlinkPhysicalRel {
- lazy val (joinInfo, filterNulls) = {
+ def getJoinInfo: JoinInfo = joinInfo
+
+ lazy val filterNulls: Array[Boolean] = {
val filterNulls = new util.ArrayList[java.lang.Boolean]
- val joinInfo = JoinUtil.createJoinInfo(getLeft, getRight, getCondition, filterNulls)
- (joinInfo, filterNulls.map(_.booleanValue()).toArray)
+ JoinUtil.createJoinInfo(getLeft, getRight, getCondition, filterNulls)
+ filterNulls.map(_.booleanValue()).toArray
}
- lazy val keyPairs: List[IntPair] = joinInfo.pairs.toList
+ lazy val keyPairs: List[IntPair] = getJoinInfo.pairs.toList
- lazy val flinkJoinType: FlinkJoinType = JoinTypeUtil.getFlinkJoinType(this)
+ // TODO remove FlinkJoinType
+ lazy val flinkJoinType: FlinkJoinType = JoinTypeUtil.getFlinkJoinType(this.getJoinType)
- lazy val inputRowType: RelDataType = this match {
- case sj: SemiJoin =>
- // Combines inputs' RowType, the result is different from SemiJoin's RowType.
+ lazy val inputRowType: RelDataType = joinType match {
+ case JoinRelType.SEMI | JoinRelType.ANTI =>
+ // Combines inputs' RowType, the result is different from SEMI/ANTI Join's RowType.
SqlValidatorUtil.deriveJoinRowType(
- sj.getLeft.getRowType,
- sj.getRight.getRowType,
+ getLeft.getRowType,
+ getRight.getRowType,
getJoinType,
- sj.getCluster.getTypeFactory,
+ getCluster.getTypeFactory,
null,
Collections.emptyList[RelDataTypeField]
)
- case j: Join => getRowType
- case _ => throw new IllegalArgumentException(s"Illegal join node: ${this.getRelTypeName}")
+ case _ =>
+ getRowType
}
override def explainTerms(pw: RelWriter): RelWriter = {
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalJoin.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalJoin.scala
index 10ec3a8..998f2c7 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalJoin.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalJoin.scala
@@ -57,15 +57,20 @@ class FlinkLogicalJoin(
override def computeSelfCost(planner: RelOptPlanner, mq: RelMetadataQuery): RelOptCost = {
val leftRowCnt = mq.getRowCount(getLeft)
val leftRowSize = mq.getAverageRowSize(getLeft)
-
val rightRowCnt = mq.getRowCount(getRight)
- val rightRowSize = mq.getAverageRowSize(getRight)
-
- val ioCost = (leftRowCnt * leftRowSize) + (rightRowCnt * rightRowSize)
- val cpuCost = leftRowCnt + rightRowCnt
- val rowCnt = leftRowCnt + rightRowCnt
- planner.getCostFactory.makeCost(rowCnt, cpuCost, ioCost)
+ joinType match {
+ case JoinRelType.SEMI | JoinRelType.ANTI =>
+ val rightRowSize = mq.getAverageRowSize(getRight)
+ val ioCost = (leftRowCnt * leftRowSize) + (rightRowCnt * rightRowSize)
+ val cpuCost = leftRowCnt + rightRowCnt
+ val rowCnt = leftRowCnt + rightRowCnt
+ planner.getCostFactory.makeCost(rowCnt, cpuCost, ioCost)
+ case _ =>
+ val cpuCost = leftRowCnt + rightRowCnt
+ val ioCost = (leftRowCnt * leftRowSize) + rightRowCnt
+ planner.getCostFactory.makeCost(leftRowCnt, cpuCost, ioCost)
+ }
}
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecHashJoin.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecHashJoin.scala
index 8c74c09..579a706 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecHashJoin.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecHashJoin.scala
@@ -41,14 +41,19 @@ import scala.collection.JavaConversions._
/**
* Batch physical RelNode for hash [[Join]].
*/
-trait BatchExecHashJoinBase extends BatchExecJoinBase {
-
- // true if LHS is build side, else false
- val leftIsBuild: Boolean
- // true if build side is broadcast, else false
- val isBroadcast: Boolean
- val tryDistinctBuildRow: Boolean
- var haveInsertRf: Boolean
+class BatchExecHashJoin(
+ cluster: RelOptCluster,
+ traitSet: RelTraitSet,
+ leftRel: RelNode,
+ rightRel: RelNode,
+ condition: RexNode,
+ joinType: JoinRelType,
+ // true if LHS is build side, else false
+ val leftIsBuild: Boolean,
+ // true if build side is broadcast, else false
+ val isBroadcast: Boolean,
+ val tryDistinctBuildRow: Boolean)
+ extends BatchExecJoinBase(cluster, traitSet, leftRel, rightRel, condition, joinType) {
private val (leftKeys, rightKeys) =
JoinUtil.checkAndGetJoinKeys(keyPairs, getLeft, getRight, allowEmptyKey = true)
@@ -61,6 +66,25 @@ trait BatchExecHashJoinBase extends BatchExecJoinBase {
val hashJoinType: HashJoinType = HashJoinType.of(leftIsBuild, getJoinType.generatesNullsOnRight(),
getJoinType.generatesNullsOnLeft())
+ override def copy(
+ traitSet: RelTraitSet,
+ conditionExpr: RexNode,
+ left: RelNode,
+ right: RelNode,
+ joinType: JoinRelType,
+ semiJoinDone: Boolean): Join = {
+ new BatchExecHashJoin(
+ cluster,
+ traitSet,
+ left,
+ right,
+ conditionExpr,
+ joinType,
+ leftIsBuild,
+ isBroadcast,
+ tryDistinctBuildRow)
+ }
+
override def explainTerms(pw: RelWriter): RelWriter = {
super.explainTerms(pw)
.itemIf("isBroadcast", "true", isBroadcast)
@@ -123,37 +147,3 @@ trait BatchExecHashJoinBase extends BatchExecJoinBase {
throw new TableException("Implements this")
}
}
-
-class BatchExecHashJoin(
- cluster: RelOptCluster,
- traitSet: RelTraitSet,
- leftRel: RelNode,
- rightRel: RelNode,
- condition: RexNode,
- joinType: JoinRelType,
- val leftIsBuild: Boolean,
- val isBroadcast: Boolean,
- override var haveInsertRf: Boolean = false)
- extends Join(cluster, traitSet, leftRel, rightRel, condition, Set.empty[CorrelationId], joinType)
- with BatchExecHashJoinBase {
-
- override val tryDistinctBuildRow = false
-
- override def copy(
- traitSet: RelTraitSet,
- conditionExpr: RexNode,
- left: RelNode,
- right: RelNode,
- joinType: JoinRelType,
- semiJoinDone: Boolean): Join =
- new BatchExecHashJoin(
- cluster,
- traitSet,
- left,
- right,
- conditionExpr,
- joinType,
- leftIsBuild,
- isBroadcast,
- haveInsertRf)
-}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecJoinBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecJoinBase.scala
index 5f1d1e7..f22a64a 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecJoinBase.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecJoinBase.scala
@@ -25,13 +25,22 @@ import org.apache.flink.table.generated.GeneratedJoinCondition
import org.apache.flink.table.plan.nodes.common.CommonPhysicalJoin
import org.apache.flink.table.plan.nodes.exec.BatchExecNode
-import org.apache.calcite.rel.core.Join
+import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.core.{Join, JoinRelType}
+import org.apache.calcite.rex.RexNode
/**
* Batch physical RelNode for [[Join]]
*/
-trait BatchExecJoinBase
- extends CommonPhysicalJoin
+abstract class BatchExecJoinBase(
+ cluster: RelOptCluster,
+ traitSet: RelTraitSet,
+ leftRel: RelNode,
+ rightRel: RelNode,
+ condition: RexNode,
+ joinType: JoinRelType)
+ extends CommonPhysicalJoin(cluster, traitSet, leftRel, rightRel, condition, joinType)
with BatchPhysicalRel
with BatchExecNode[BaseRow] {
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecNestedLoopJoin.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecNestedLoopJoin.scala
index a1888dd..a9919f8 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecNestedLoopJoin.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecNestedLoopJoin.scala
@@ -39,12 +39,36 @@ import scala.collection.JavaConversions._
/**
* Batch physical RelNode for nested-loop [[Join]].
*/
-trait BatchExecNestedLoopJoinBase extends BatchExecJoinBase {
+class BatchExecNestedLoopJoin(
+ cluster: RelOptCluster,
+ traitSet: RelTraitSet,
+ leftRel: RelNode,
+ rightRel: RelNode,
+ condition: RexNode,
+ joinType: JoinRelType,
+ // true if LHS is build side, else RHS is build side
+ val leftIsBuild: Boolean,
+ // true if one side returns single row, else false
+ val singleRowJoin: Boolean)
+ extends BatchExecJoinBase(cluster, traitSet, leftRel, rightRel, condition, joinType) {
- // true if LHS is build side, else RHS is build side
- val leftIsBuild: Boolean
- // true if one side returns single row, else false
- val singleRowJoin: Boolean
+ override def copy(
+ traitSet: RelTraitSet,
+ conditionExpr: RexNode,
+ left: RelNode,
+ right: RelNode,
+ joinType: JoinRelType,
+ semiJoinDone: Boolean): Join = {
+ new BatchExecNestedLoopJoin(
+ cluster,
+ traitSet,
+ left,
+ right,
+ conditionExpr,
+ joinType,
+ leftIsBuild,
+ singleRowJoin)
+ }
override def explainTerms(pw: RelWriter): RelWriter = {
super.explainTerms(pw)
@@ -99,34 +123,3 @@ trait BatchExecNestedLoopJoinBase extends BatchExecJoinBase {
}
}
-
-class BatchExecNestedLoopJoin(
- cluster: RelOptCluster,
- traitSet: RelTraitSet,
- leftRel: RelNode,
- rightRel: RelNode,
- condition: RexNode,
- joinType: JoinRelType,
- val leftIsBuild: Boolean,
- val singleRowJoin: Boolean)
- extends Join(cluster, traitSet, leftRel, rightRel, condition, Set.empty[CorrelationId], joinType)
- with BatchExecNestedLoopJoinBase {
-
- override def copy(
- traitSet: RelTraitSet,
- conditionExpr: RexNode,
- left: RelNode,
- right: RelNode,
- joinType: JoinRelType,
- semiJoinDone: Boolean): Join =
- new BatchExecNestedLoopJoin(
- cluster,
- traitSet,
- left,
- right,
- conditionExpr,
- joinType,
- leftIsBuild,
- singleRowJoin)
-}
-
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecSortMergeJoin.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecSortMergeJoin.scala
index 19fefe2..fd24c16 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecSortMergeJoin.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecSortMergeJoin.scala
@@ -28,7 +28,7 @@ import org.apache.flink.table.codegen.sort.SortCodeGenerator
import org.apache.flink.table.dataformat.BaseRow
import org.apache.flink.table.plan.cost.{FlinkCost, FlinkCostFactory}
import org.apache.flink.table.plan.nodes.ExpressionFormat
-import org.apache.flink.table.plan.nodes.exec.{BatchExecNode, ExecNode}
+import org.apache.flink.table.plan.nodes.exec.ExecNode
import org.apache.flink.table.plan.util.{FlinkRelMdUtil, JoinUtil, SortUtil}
import org.apache.flink.table.runtime.join.{FlinkJoinType, SortMergeJoinOperator}
@@ -45,12 +45,18 @@ import scala.collection.JavaConversions._
/**
* Batch physical RelNode for sort-merge [[Join]].
*/
-trait BatchExecSortMergeJoinBase extends BatchExecJoinBase with BatchExecNode[BaseRow] {
-
- // true if LHS is sorted by left join keys, else false
- val leftSorted: Boolean
- // true if RHS is sorted by right join key, else false
- val rightSorted: Boolean
+class BatchExecSortMergeJoin(
+ cluster: RelOptCluster,
+ traitSet: RelTraitSet,
+ leftRel: RelNode,
+ rightRel: RelNode,
+ condition: RexNode,
+ joinType: JoinRelType,
+ // true if LHS is sorted by left join keys, else false
+ val leftSorted: Boolean,
+ // true if RHS is sorted by right join key, else false
+ val rightSorted: Boolean)
+ extends BatchExecJoinBase(cluster, traitSet, leftRel, rightRel, condition, joinType) {
protected lazy val (leftAllKey, rightAllKey) =
JoinUtil.checkAndGetJoinKeys(keyPairs, getLeft, getRight)
@@ -76,6 +82,24 @@ trait BatchExecSortMergeJoinBase extends BatchExecJoinBase with BatchExecNode[Ba
}
}
+ override def copy(
+ traitSet: RelTraitSet,
+ conditionExpr: RexNode,
+ left: RelNode,
+ right: RelNode,
+ joinType: JoinRelType,
+ semiJoinDone: Boolean): Join = {
+ new BatchExecSortMergeJoin(
+ cluster,
+ traitSet,
+ left,
+ right,
+ conditionExpr,
+ joinType,
+ leftSorted,
+ rightSorted)
+ }
+
override def explainTerms(pw: RelWriter): RelWriter =
super.explainTerms(pw)
.itemIf("leftSorted", leftSorted, leftSorted)
@@ -226,33 +250,3 @@ object SortMergeJoinType extends Enumeration {
// both LHS and RHS need sort
SortMergeJoin = Value
}
-
-class BatchExecSortMergeJoin(
- cluster: RelOptCluster,
- traitSet: RelTraitSet,
- leftRel: RelNode,
- rightRel: RelNode,
- condition: RexNode,
- joinType: JoinRelType,
- override val leftSorted: Boolean,
- override val rightSorted: Boolean)
- extends Join(cluster, traitSet, leftRel, rightRel, condition, Set.empty[CorrelationId], joinType)
- with BatchExecSortMergeJoinBase {
-
- override def copy(
- traitSet: RelTraitSet,
- conditionExpr: RexNode,
- left: RelNode,
- right: RelNode,
- joinType: JoinRelType,
- semiJoinDone: Boolean): Join =
- new BatchExecSortMergeJoin(
- cluster,
- traitSet,
- left,
- right,
- conditionExpr,
- joinType,
- leftSorted,
- rightSorted)
-}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecJoinBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecJoin.scala
similarity index 94%
rename from flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecJoinBase.scala
rename to flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecJoin.scala
index e9fb5a7..628fd0e 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecJoinBase.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecJoin.scala
@@ -28,7 +28,7 @@ import org.apache.flink.table.runtime.join.FlinkJoinType
import org.apache.calcite.plan._
import org.apache.calcite.plan.hep.HepRelVertex
import org.apache.calcite.rel.RelNode
-import org.apache.calcite.rel.core.{CorrelationId, Join, JoinRelType}
+import org.apache.calcite.rel.core.{Join, JoinRelType}
import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.rex.RexNode
@@ -42,8 +42,14 @@ import scala.collection.JavaConversions._
* Regular joins are the most generic type of join in which any new records or changes to
* either side of the join input are visible and are affecting the whole join result.
*/
-trait StreamExecJoinBase
- extends CommonPhysicalJoin
+class StreamExecJoin(
+ cluster: RelOptCluster,
+ traitSet: RelTraitSet,
+ leftRel: RelNode,
+ rightRel: RelNode,
+ condition: RexNode,
+ joinType: JoinRelType)
+ extends CommonPhysicalJoin(cluster, traitSet, leftRel, rightRel, condition, joinType)
with StreamPhysicalRel
with StreamExecNode[BaseRow] {
@@ -88,6 +94,16 @@ trait StreamExecJoinBase
override def requireWatermark: Boolean = false
+ override def copy(
+ traitSet: RelTraitSet,
+ conditionExpr: RexNode,
+ left: RelNode,
+ right: RelNode,
+ joinType: JoinRelType,
+ semiJoinDone: Boolean): Join = {
+ new StreamExecJoin(cluster, traitSet, left, right, conditionExpr, joinType)
+ }
+
override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
val elementRate = 100.0d * 2 // two input stream
planner.getCostFactory.makeCost(elementRate, elementRate, 0)
@@ -111,24 +127,3 @@ trait StreamExecJoinBase
}
}
-
-class StreamExecJoin(
- cluster: RelOptCluster,
- traitSet: RelTraitSet,
- leftRel: RelNode,
- rightRel: RelNode,
- condition: RexNode,
- joinType: JoinRelType)
- extends Join(cluster, traitSet, leftRel, rightRel, condition, Set.empty[CorrelationId], joinType)
- with StreamExecJoinBase {
-
- override def copy(
- traitSet: RelTraitSet,
- conditionExpr: RexNode,
- left: RelNode,
- right: RelNode,
- joinType: JoinRelType,
- semiJoinDone: Boolean): Join = {
- new StreamExecJoin(cluster, traitSet, left, right, conditionExpr, joinType)
- }
-}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecWindowJoin.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecWindowJoin.scala
index abc3c51..4da273f 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecWindowJoin.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecWindowJoin.scala
@@ -65,7 +65,8 @@ class StreamExecWindowJoin(
with StreamPhysicalRel
with StreamExecNode[BaseRow] {
- private lazy val flinkJoinType: FlinkJoinType = JoinTypeUtil.toFlinkJoinType(joinType)
+ // TODO remove FlinkJoinType
+ private lazy val flinkJoinType: FlinkJoinType = JoinTypeUtil.getFlinkJoinType(joinType)
override def producesUpdates: Boolean = false
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/program/FlinkBatchProgram.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/program/FlinkBatchProgram.scala
index 603277a..d91779b 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/program/FlinkBatchProgram.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/program/FlinkBatchProgram.scala
@@ -40,9 +40,20 @@ object FlinkBatchProgram {
val chainedProgram = new FlinkChainedProgram[BatchOptimizeContext]()
chainedProgram.addLast(
- // rewrite sub-queries to joins
+ // rewrite sub-queries to joins
SUBQUERY_REWRITE,
FlinkGroupProgramBuilder.newBuilder[BatchOptimizeContext]
+ // rewrite RelTable before rewriting sub-queries
+ .addProgram(FlinkHepRuleSetProgramBuilder.newBuilder
+ .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE)
+ .setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
+ .add(FlinkBatchRuleSets.TABLE_REF_RULES)
+ .build(), "convert table references before rewriting sub-queries to semi-join")
+ .addProgram(FlinkHepRuleSetProgramBuilder.newBuilder
+ .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE)
+ .setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
+ .add(FlinkBatchRuleSets.SEMI_JOIN_RULES)
+ .build(), "rewrite sub-queries to semi-join")
.addProgram(FlinkHepRuleSetProgramBuilder.newBuilder
.setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_COLLECTION)
.setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
@@ -50,10 +61,10 @@ object FlinkBatchProgram {
.build(), "sub-queries remove")
// convert RelOptTableImpl (which exists in SubQuery before) to FlinkRelOptTable
.addProgram(FlinkHepRuleSetProgramBuilder.newBuilder
- .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE)
- .setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
- .add(FlinkBatchRuleSets.TABLE_REF_RULES)
- .build(), "convert table references after sub-queries removed")
+ .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE)
+ .setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
+ .add(FlinkBatchRuleSets.TABLE_REF_RULES)
+ .build(), "convert table references after sub-queries removed")
.build()
)
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/program/FlinkStreamProgram.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/program/FlinkStreamProgram.scala
index 08547f8..46e2a08 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/program/FlinkStreamProgram.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/program/FlinkStreamProgram.scala
@@ -46,6 +46,17 @@ object FlinkStreamProgram {
chainedProgram.addLast(
SUBQUERY_REWRITE,
FlinkGroupProgramBuilder.newBuilder[StreamOptimizeContext]
+ // rewrite RelTable before rewriting sub-queries
+ .addProgram(FlinkHepRuleSetProgramBuilder.newBuilder
+ .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE)
+ .setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
+ .add(FlinkStreamRuleSets.TABLE_REF_RULES)
+ .build(), "convert table references before rewriting sub-queries to semi-join")
+ .addProgram(FlinkHepRuleSetProgramBuilder.newBuilder
+ .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE)
+ .setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
+ .add(FlinkStreamRuleSets.SEMI_JOIN_RULES)
+ .build(), "rewrite sub-queries to semi-join")
.addProgram(FlinkHepRuleSetProgramBuilder.newBuilder
.setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_COLLECTION)
.setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/reuse/DeadlockBreakupProcessor.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/reuse/DeadlockBreakupProcessor.scala
index 7dbbad5..475c0c6 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/reuse/DeadlockBreakupProcessor.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/reuse/DeadlockBreakupProcessor.scala
@@ -182,12 +182,12 @@ class DeadlockBreakupProcessor {
override def visit(node: ExecNode[_, _]): Unit = {
super.visit(node)
node match {
- case hashJoin: BatchExecHashJoinBase =>
- val joinInfo = hashJoin.joinInfo
+ case hashJoin: BatchExecHashJoin =>
+ val joinInfo = hashJoin.getJoinInfo
val columns = if (hashJoin.leftIsBuild) joinInfo.rightKeys else joinInfo.leftKeys
val distribution = FlinkRelDistribution.hash(columns)
rewriteJoin(hashJoin, hashJoin.leftIsBuild, distribution)
- case nestedLoopJoin: BatchExecNestedLoopJoinBase =>
+ case nestedLoopJoin: BatchExecNestedLoopJoin =>
rewriteJoin(nestedLoopJoin, nestedLoopJoin.leftIsBuild, FlinkRelDistribution.ANY)
case _ => // do nothing
}
@@ -323,11 +323,11 @@ class DeadlockBreakupProcessor {
true
} else {
node match {
- case h: BatchExecHashJoinBase =>
+ case h: BatchExecHashJoin =>
val buildSideIndex = if (h.leftIsBuild) 0 else 1
val buildNode = h.getInputNodes.get(buildSideIndex)
checkJoinBuildSide(buildNode, idx, inputPath)
- case n: BatchExecNestedLoopJoinBase =>
+ case n: BatchExecNestedLoopJoin =>
val buildSideIndex = if (n.leftIsBuild) 0 else 1
val buildNode = n.getInputNodes.get(buildSideIndex)
checkJoinBuildSide(buildNode, idx, inputPath)
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala
index eeba4d5..e87b77e 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala
@@ -32,6 +32,12 @@ import scala.collection.JavaConverters._
object FlinkBatchRuleSets {
+ val SEMI_JOIN_RULES: RuleSet = RuleSets.ofList(
+ SimplifyFilterConditionRule.EXTENDED,
+ FlinkSubQueryRemoveRule.FILTER,
+ FlinkJoinPushExpressionsRule.INSTANCE
+ )
+
/**
* Convert sub-queries before query decorrelation.
*/
@@ -82,9 +88,9 @@ object FlinkBatchRuleSets {
*/
private val FILTER_RULES: RuleSet = RuleSets.ofList(
// push a filter into a join
- FilterJoinRule.FILTER_ON_JOIN,
+ FlinkFilterJoinRule.FILTER_ON_JOIN,
// push filter into the children of a join
- FilterJoinRule.JOIN,
+ FlinkFilterJoinRule.JOIN,
// push filter through an aggregation
FilterAggregateTransposeRule.INSTANCE,
// push a filter past a project
@@ -124,7 +130,8 @@ object FlinkBatchRuleSets {
ProjectFilterTransposeRule.INSTANCE,
// push a projection to the children of a join
// push all expressions to handle the time indicator correctly
- new ProjectJoinTransposeRule(PushProjector.ExprCondition.FALSE, RelFactories.LOGICAL_BUILDER),
+ new FlinkProjectJoinTransposeRule(
+ PushProjector.ExprCondition.FALSE, RelFactories.LOGICAL_BUILDER),
// merge projections
ProjectMergeRule.INSTANCE,
// remove identity project
@@ -156,7 +163,7 @@ object FlinkBatchRuleSets {
SortProjectTransposeRule.INSTANCE,
// join rules
- JoinPushExpressionsRule.INSTANCE,
+ FlinkJoinPushExpressionsRule.INSTANCE,
// remove union with only a single child
UnionEliminatorRule.INSTANCE,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala
index 794cdff..6ac853d 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala
@@ -32,6 +32,12 @@ import scala.collection.JavaConverters._
object FlinkStreamRuleSets {
+ val SEMI_JOIN_RULES: RuleSet = RuleSets.ofList(
+ SimplifyFilterConditionRule.EXTENDED,
+ FlinkSubQueryRemoveRule.FILTER,
+ FlinkJoinPushExpressionsRule.INSTANCE
+ )
+
/**
* Convert sub-queries before query decorrelation.
*/
@@ -85,9 +91,9 @@ object FlinkStreamRuleSets {
*/
private val FILTER_RULES: RuleSet = RuleSets.ofList(
// push a filter into a join
- FilterJoinRule.FILTER_ON_JOIN,
+ FlinkFilterJoinRule.FILTER_ON_JOIN,
// push filter into the children of a join
- FilterJoinRule.JOIN,
+ FlinkFilterJoinRule.JOIN,
// push filter through an aggregation
FilterAggregateTransposeRule.INSTANCE,
// push a filter past a project
@@ -127,7 +133,8 @@ object FlinkStreamRuleSets {
ProjectFilterTransposeRule.INSTANCE,
// push a projection to the children of a join
// push all expressions to handle the time indicator correctly
- new ProjectJoinTransposeRule(PushProjector.ExprCondition.FALSE, RelFactories.LOGICAL_BUILDER),
+ new FlinkProjectJoinTransposeRule(
+ PushProjector.ExprCondition.FALSE, RelFactories.LOGICAL_BUILDER),
// merge projections
ProjectMergeRule.INSTANCE,
// remove identity project
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkSubQueryRemoveRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkSubQueryRemoveRule.scala
new file mode 100644
index 0000000..c55fa8f
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkSubQueryRemoveRule.scala
@@ -0,0 +1,459 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.rules.logical
+
+import org.apache.flink.table.calcite.{FlinkRelBuilder, FlinkRelFactories}
+
+import com.google.common.collect.ImmutableList
+import org.apache.calcite.plan.RelOptRule._
+import org.apache.calcite.plan.RelOptUtil.Logic
+import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptRuleOperand}
+import org.apache.calcite.rel.core.{Filter, JoinRelType}
+import org.apache.calcite.rel.logical.{LogicalFilter, LogicalJoin, LogicalProject}
+import org.apache.calcite.rel.{RelNode, RelShuttleImpl}
+import org.apache.calcite.rex._
+import org.apache.calcite.sql.SqlKind
+import org.apache.calcite.tools.{RelBuilder, RelBuilderFactory}
+import org.apache.calcite.util.Util
+
+import java.util
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable
+
+/**
+ * Planner rule that converts IN and EXISTS into semi-join,
+ * converts NOT IN and NOT EXISTS into anti-join.
+ *
+ * <p>Sub-queries are represented by [[RexSubQuery]] expressions.
+ *
+ * <p>A sub-query may or may not be correlated. If a sub-query is correlated,
+ * the wrapped [[RelNode]] will contain a [[RexCorrelVariable]] before the rewrite,
+ * and the product of the rewrite will be a [[org.apache.calcite.rel.core.Join]]
+ * with SEMI or ANTI join type.
+ */
+class FlinkSubQueryRemoveRule(
+ operand: RelOptRuleOperand,
+ relBuilderFactory: RelBuilderFactory,
+ description: String)
+ extends RelOptRule(operand, relBuilderFactory, description) {
+
+ override def onMatch(call: RelOptRuleCall): Unit = {
+ val filter: Filter = call.rel(0)
+ val condition = filter.getCondition
+
+ if (hasUnsupportedSubQuery(condition)) {
+ // the condition contains an unsupported sub-query, e.g. a sub-query connected with OR:
+ // select * from t1 where t1.a > 10 or t1.b in (select t2.c from t2)
+ // TODO support ExistenceJoin
+ return
+ }
+
+ val subQueryCall = findSubQuery(condition)
+ if (subQueryCall.isEmpty) {
+ // no non-scalar sub-query found; scalar queries are not handled by this rule
+ return
+ }
+
+ val decorrelate = SubQueryDecorrelator.decorrelateQuery(filter)
+ if (decorrelate == null) {
+ // can't handle the query
+ return
+ }
+
+ val relBuilder = call.builder.asInstanceOf[FlinkRelBuilder]
+ relBuilder.push(filter.getInput) // push join left
+
+ val newCondition = handleSubQuery(subQueryCall.get, condition, relBuilder, decorrelate)
+ newCondition match {
+ case Some(c) =>
+ if (hasCorrelatedExpressions(c)) {
+ // some correlated expressions can not be replaced in this rule,
+ // so we must keep the VariablesSet for decorrelating later in new filter
+ // RelBuilder.filter can not create Filter with VariablesSet arg
+ val newFilter = filter.copy(filter.getTraitSet, relBuilder.build(), c)
+ relBuilder.push(newFilter)
+ } else {
+ // all correlated expressions are replaced,
+ // so we can create a new filter without any VariablesSet
+ relBuilder.filter(c)
+ }
+ relBuilder.project(fields(relBuilder, filter.getRowType.getFieldCount))
+ call.transformTo(relBuilder.build)
+ case _ => // do nothing
+ }
+ }
+
+ def handleSubQuery(
+ subQueryCall: RexCall,
+ condition: RexNode,
+ relBuilder: FlinkRelBuilder,
+ decorrelate: SubQueryDecorrelator.Result): Option[RexNode] = {
+ val logic = LogicVisitor.find(Logic.TRUE, ImmutableList.of(condition), subQueryCall)
+ if (logic != Logic.TRUE) {
+ // this should not happen: unsupported sub-queries are rejected before reaching here;
+ // this is just a double-check
+ return None
+ }
+
+ val target = apply(subQueryCall, relBuilder, decorrelate)
+ if (target.isEmpty) {
+ return None
+ }
+
+ val newCondition = replaceSubQuery(condition, subQueryCall, target.get)
+ val nextSubQueryCall = findSubQuery(newCondition)
+ nextSubQueryCall match {
+ case Some(subQuery) => handleSubQuery(subQuery, newCondition, relBuilder, decorrelate)
+ case _ => Some(newCondition)
+ }
+ }
+
+ private def apply(
+ subQueryCall: RexCall,
+ relBuilder: FlinkRelBuilder,
+ decorrelate: SubQueryDecorrelator.Result): Option[RexNode] = {
+
+ val (subQuery: RexSubQuery, withNot: Boolean) = subQueryCall match {
+ case s: RexSubQuery => (s, false)
+ case c: RexCall => (c.operands.head, true)
+ }
+
+ val equivalent = decorrelate.getSubQueryEquivalent(subQuery)
+
+ subQuery.getKind match {
+ // IN and NOT IN
+ //
+ // NOT IN is a NULL-aware (left) anti join, e.g. `col NOT IN expr`. In the constructed
+ // join condition, a NULL in one of the conditions is regarded as a positive result;
+ // such a row will be filtered out by the Anti-Join operator.
+ //
+ // Rewrite logic for NOT IN:
+ // Expand the NOT IN expression with the NULL-aware semantic to its full form.
+ // That is from:
+ // (a1,a2,...) = (b1,b2,...)
+ // to
+ // (a1=b1 OR isnull(a1=b1)) AND (a2=b2 OR isnull(a2=b2)) AND ...
+ //
+ // After that, add back the correlated join predicate(s) in the subquery
+ // Example:
+ // SELECT ... FROM A WHERE A.A1 NOT IN (SELECT B.B1 FROM B WHERE B.B2 = A.A2 AND B.B3 > 1)
+ // will have the final conditions in the ANTI JOIN as
+ // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2)
+ case SqlKind.IN =>
+ // TODO:
+ // Calcite does not support project with correlated expressions.
+ // e.g.
+ // SELECT b FROM l WHERE (
+ // CASE WHEN a IN (SELECT i FROM t1 WHERE l.b = t1.j) THEN 1 ELSE 2 END)
+ // IN (SELECT d FROM r)
+ //
+ // we can not create project with VariablesSet arg, and
+ // the result of RelDecorrelator is also wrong.
+ if (hasCorrelatedExpressions(subQuery.getOperands: _*)) {
+ return None
+ }
+
+ val (newRight, joinCondition) = if (equivalent != null) {
+ // IN has correlation variables
+ (equivalent.getKey, Some(equivalent.getValue))
+ } else {
+ // IN has no correlation variables
+ (subQuery.rel, None)
+ }
+ // add a projection if the operands of IN contain non-RexInputRef nodes
+ // e.g. SELECT * FROM l WHERE a + 1 IN (SELECT c FROM r)
+ val (newOperands, newJoinCondition) =
+ handleSubQueryOperands(subQuery, joinCondition, relBuilder)
+ val leftFieldCount = relBuilder.peek().getRowType.getFieldCount
+
+ relBuilder.push(newRight) // push join right
+
+ val joinConditions = newOperands
+ .zip(relBuilder.fields())
+ .map { case (op, f) =>
+ val inCondition = relBuilder.equals(op, RexUtil.shift(f, leftFieldCount))
+ if (withNot) {
+ relBuilder.or(inCondition, relBuilder.isNull(inCondition))
+ } else {
+ inCondition
+ }
+ }.toBuffer
+
+ newJoinCondition.foreach(joinConditions += _)
+
+ if (withNot) {
+ relBuilder.join(JoinRelType.ANTI, joinConditions)
+ } else {
+ relBuilder.join(JoinRelType.SEMI, joinConditions)
+ }
+ Some(relBuilder.literal(true))
+
+ // EXISTS and NOT EXISTS
+ case SqlKind.EXISTS =>
+ val joinCondition = if (equivalent != null) {
+ // EXISTS has correlation variables
+ relBuilder.push(equivalent.getKey) // push join right
+ require(equivalent.getValue != null)
+ equivalent.getValue
+ } else {
+ // EXISTS has no correlation variables
+ //
+ // e.g. (table `l` has two columns: `a`, `b`, and table `r` has two columns: `c`, `d`)
+ // SELECT * FROM l WHERE EXISTS (SELECT * FROM r)
+ // which can be converted to:
+ //
+ // LogicalProject(a=[$0], b=[$1])
+ // LogicalJoin(condition=[$2], joinType=[semi])
+ // LogicalTableScan(table=[[builtin, default, l]])
+ // LogicalProject($f0=[IS NOT NULL($0)])
+ // LogicalAggregate(group=[{}], m=[MIN($0)])
+ // LogicalProject(i=[true])
+ // LogicalTableScan(table=[[builtin, default, r]])
+ //
+ // MIN($0) will return null when table `r` is empty,
+ // so add LogicalProject($f0=[IS NOT NULL($0)]) to check null value
+ val leftFieldCount = relBuilder.peek().getRowType.getFieldCount
+ relBuilder.push(subQuery.rel) // push join right
+ // adds LogicalProject(i=[true]) to join right
+ relBuilder.project(relBuilder.alias(relBuilder.literal(true), "i"))
+ // adds LogicalAggregate(group=[{}], m=[MIN($0)]) to join right
+ relBuilder.aggregate(relBuilder.groupKey(), relBuilder.min("m", relBuilder.field(0)))
+ // adds LogicalProject($f0=[IS NOT NULL($0)]) to check null value
+ relBuilder.project(relBuilder.isNotNull(relBuilder.field(0)))
+ val fieldType = relBuilder.peek().getRowType.getFieldList.get(0).getType
+ // join condition references project result directly
+ new RexInputRef(leftFieldCount, fieldType)
+ }
+
+ if (withNot) {
+ relBuilder.join(JoinRelType.ANTI, joinCondition)
+ } else {
+ relBuilder.join(JoinRelType.SEMI, joinCondition)
+ }
+ Some(relBuilder.literal(true))
+
+ case _ => None
+ }
+ }
+
+ private def fields(builder: RelBuilder, fieldCount: Int): util.List[RexNode] = {
+ val projects: util.List[RexNode] = new util.ArrayList[RexNode]()
+ (0 until fieldCount).foreach(i => projects.add(builder.field(i)))
+ projects
+ }
+
+ private def isScalarQuery(n: RexNode): Boolean = n.isA(SqlKind.SCALAR_QUERY)
+
+ private def findSubQuery(node: RexNode): Option[RexCall] = {
+ val subQueryFinder = new RexVisitorImpl[Unit](true) {
+ override def visitSubQuery(subQuery: RexSubQuery): Unit = {
+ if (!isScalarQuery(subQuery)) {
+ throw new Util.FoundOne(subQuery)
+ }
+ }
+
+ override def visitCall(call: RexCall): Unit = {
+ call.getKind match {
+ case SqlKind.NOT if call.operands.head.isInstanceOf[RexSubQuery] =>
+ if (!isScalarQuery(call.operands.head)) {
+ throw new Util.FoundOne(call)
+ }
+ case _ =>
+ super.visitCall(call)
+ }
+ }
+ }
+
+ try {
+ node.accept(subQueryFinder)
+ None
+ } catch {
+ case e: Util.FoundOne => Some(e.getNode.asInstanceOf[RexCall])
+ }
+ }
+
+ private def replaceSubQuery(
+ condition: RexNode,
+ oldSubQueryCall: RexCall,
+ replacement: RexNode): RexNode = {
+ condition.accept(new RexShuttle() {
+ override def visitSubQuery(subQuery: RexSubQuery): RexNode = {
+ if (RexUtil.eq(subQuery, oldSubQueryCall)) replacement else subQuery
+ }
+
+ override def visitCall(call: RexCall): RexNode = {
+ call.getKind match {
+ case SqlKind.NOT if call.operands.head.isInstanceOf[RexSubQuery] =>
+ if (RexUtil.eq(call, oldSubQueryCall)) replacement else call
+ case _ =>
+ super.visitCall(call)
+ }
+ }
+ })
+ }
+
+ /**
+ * Adds a projection if the operands of a SubQuery contain non-RexInputRef nodes,
+ * and returns the SubQuery's new operands and the new join condition with adjusted indices.
+ *
+ * e.g. SELECT * FROM l WHERE a + 1 IN (SELECT c FROM r)
+ * We will add a projection as the left input of the SEMI join; the added projection passes
+ * along all fields from the input and adds `a + 1` as a new field.
+ */
+ private def handleSubQueryOperands(
+ subQuery: RexSubQuery,
+ joinCondition: Option[RexNode],
+ relBuilder: RelBuilder): (Seq[RexNode], Option[RexNode]) = {
+ val operands = subQuery.getOperands
+ // operands is empty or all operands are RexInputRef
+ if (operands.isEmpty || operands.forall(_.isInstanceOf[RexInputRef])) {
+ return (operands, joinCondition)
+ }
+
+ val rexBuilder = relBuilder.getRexBuilder
+ val oldLeftNode = relBuilder.peek()
+ val oldLeftFieldCount = oldLeftNode.getRowType.getFieldCount
+ val newLeftProjects = mutable.ListBuffer[RexNode]()
+ val newOperandIndices = mutable.ListBuffer[Int]()
+ (0 until oldLeftFieldCount).foreach(newLeftProjects += rexBuilder.makeInputRef(oldLeftNode, _))
+ operands.foreach { o =>
+ var index = newLeftProjects.indexOf(o)
+ if (index < 0) {
+ index = newLeftProjects.size
+ newLeftProjects += o
+ }
+ newOperandIndices += index
+ }
+
+ // adjust the join condition after adding the new projection
+ val newJoinCondition = if (joinCondition.isDefined) {
+ val offset = newLeftProjects.size - oldLeftFieldCount
+ Some(RexUtil.shift(joinCondition.get, oldLeftFieldCount, offset))
+ } else {
+ None
+ }
+
+ relBuilder.project(newLeftProjects) // push new join left
+ val newOperands = newOperandIndices.map(rexBuilder.makeInputRef(relBuilder.peek(), _))
+ (newOperands, newJoinCondition)
+ }
+
+ /**
+ * Checks whether the condition contains an unsupported SubQuery.
+ *
+ * Currently, only a single SubQuery or SubQueries connected with AND are supported.
+ */
+ private def hasUnsupportedSubQuery(condition: RexNode): Boolean = {
+ val visitor = new RexVisitorImpl[Unit](true) {
+ val stack: util.Deque[SqlKind] = new util.ArrayDeque[SqlKind]()
+
+ private def checkAndConjunctions(call: RexCall): Unit = {
+ if (stack.exists(_ ne SqlKind.AND)) {
+ throw new Util.FoundOne(call)
+ }
+ }
+
+ override def visitSubQuery(subQuery: RexSubQuery): Unit = {
+ // ignore scalar query
+ if (!isScalarQuery(subQuery)) {
+ checkAndConjunctions(subQuery)
+ }
+ }
+
+ override def visitCall(call: RexCall): Unit = {
+ call.getKind match {
+ case SqlKind.NOT if call.operands.head.isInstanceOf[RexSubQuery] =>
+ // ignore scalar query
+ if (!isScalarQuery(call.operands.head)) {
+ checkAndConjunctions(call)
+ }
+ case _ =>
+ stack.push(call.getKind)
+ call.operands.foreach(_.accept(this))
+ stack.pop()
+ }
+ }
+ }
+
+ try {
+ condition.accept(visitor)
+ false
+ } catch {
+ case _: Util.FoundOne => true
+ }
+ }
+
+ /**
+ * Checks whether the SubQueries in the given nodes contain correlated expressions.
+ */
+ private def hasCorrelatedExpressions(nodes: RexNode*): Boolean = {
+ val relShuttle = new RelShuttleImpl() {
+ private val corVarFinder = new RexVisitorImpl[Unit](true) {
+ override def visitCorrelVariable(corVar: RexCorrelVariable): Unit = {
+ throw new Util.FoundOne(corVar)
+ }
+ }
+
+ override def visit(filter: LogicalFilter): RelNode = {
+ filter.getCondition.accept(corVarFinder)
+ super.visit(filter)
+ }
+
+ override def visit(join: LogicalJoin): RelNode = {
+ join.getCondition.accept(corVarFinder)
+ super.visit(join)
+ }
+
+ override def visit(project: LogicalProject): RelNode = {
+ project.getProjects.foreach(_.accept(corVarFinder))
+ super.visit(project)
+ }
+ }
+
+ val subQueryFinder = new RexVisitorImpl[Unit](true) {
+ override def visitSubQuery(subQuery: RexSubQuery): Unit = {
+ subQuery.rel.accept(relShuttle)
+ }
+ }
+
+ nodes.foldLeft(false) {
+ case (found, c) =>
+ if (!found) {
+ try {
+ c.accept(subQueryFinder)
+ false
+ } catch {
+ case _: Util.FoundOne => true
+ }
+ } else {
+ true
+ }
+ }
+ }
+}
+
+object FlinkSubQueryRemoveRule {
+
+ val FILTER = new FlinkSubQueryRemoveRule(
+ operandJ(classOf[Filter], null, RexUtil.SubQueryFinder.FILTER_PREDICATE, any),
+ FlinkRelFactories.FLINK_REL_BUILDER,
+ "FlinkSubQueryRemoveRule:Filter")
+
+}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/SimplifyFilterConditionRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/SimplifyFilterConditionRule.scala
new file mode 100644
index 0000000..92b96b2
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/SimplifyFilterConditionRule.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.rules.logical
+
+import org.apache.flink.table.plan.util.FlinkRexUtil
+
+import org.apache.calcite.plan.RelOptRule.{any, operand}
+import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
+import org.apache.calcite.rel.core.Filter
+import org.apache.calcite.rel.logical.LogicalFilter
+import org.apache.calcite.rel.{RelNode, RelShuttleImpl}
+import org.apache.calcite.rex._
+
+/**
+ * Planner rule that applies various simplifying transformations to a filter condition.
+ *
+ * If `simplifySubQuery` is true, this rule will also simplify the filter conditions
+ * inside [[RexSubQuery]].
+ */
+class SimplifyFilterConditionRule(
+ simplifySubQuery: Boolean,
+ description: String)
+ extends RelOptRule(
+ operand(classOf[Filter], any()),
+ description) {
+
+ override def onMatch(call: RelOptRuleCall): Unit = {
+ val filter: Filter = call.rel(0)
+ val changed = Array(false)
+ val newFilter = simplify(filter, changed)
+ newFilter match {
+ case Some(f) =>
+ call.transformTo(f)
+ call.getPlanner.setImportance(filter, 0.0)
+ case _ => // do nothing
+ }
+ }
+
+ def simplify(filter: Filter, changed: Array[Boolean]): Option[Filter] = {
+ val condition = if (simplifySubQuery) {
+ simplifyFilterConditionInSubQuery(filter.getCondition, changed)
+ } else {
+ filter.getCondition
+ }
+
+ val rexBuilder = filter.getCluster.getRexBuilder
+ val simplifiedCondition = FlinkRexUtil.simplify(rexBuilder, condition)
+ val newCondition = RexUtil.pullFactors(rexBuilder, simplifiedCondition)
+
+ if (!changed.head && !RexUtil.eq(condition, newCondition)) {
+ changed(0) = true
+ }
+
+ // only create a new Filter if the condition was actually modified
+ if (changed.head) {
+ Some(filter.copy(filter.getTraitSet, filter.getInput, newCondition))
+ } else {
+ None
+ }
+ }
+
+ def simplifyFilterConditionInSubQuery(condition: RexNode, changed: Array[Boolean]): RexNode = {
+ condition.accept(new RexShuttle() {
+ override def visitSubQuery(subQuery: RexSubQuery): RexNode = {
+ val newRel = subQuery.rel.accept(new RelShuttleImpl() {
+ override def visit(filter: LogicalFilter): RelNode = {
+ simplify(filter, changed).getOrElse(filter)
+ }
+ })
+ if (changed.head) {
+ subQuery.clone(newRel)
+ } else {
+ subQuery
+ }
+ }
+ })
+ }
+
+}
+
+object SimplifyFilterConditionRule {
+ val INSTANCE = new SimplifyFilterConditionRule(
+ false, "SimplifyFilterConditionRule")
+
+ val EXTENDED = new SimplifyFilterConditionRule(
+ true, "SimplifyFilterConditionRule:simplifySubQuery")
+}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecHashJoinRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecHashJoinRule.scala
index a4a6bc2..06f187e 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecHashJoinRule.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecHashJoinRule.scala
@@ -25,12 +25,11 @@ import org.apache.flink.table.plan.`trait`.FlinkRelDistribution
import org.apache.flink.table.plan.nodes.FlinkConventions
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalJoin
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecHashJoin
-import org.apache.flink.table.runtime.join.FlinkJoinType
import org.apache.calcite.plan.RelOptRule.{any, operand}
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet}
import org.apache.calcite.rel.RelNode
-import org.apache.calcite.rel.core.{Join, SemiJoin}
+import org.apache.calcite.rel.core.{Join, JoinRelType}
import org.apache.calcite.util.ImmutableIntList
import java.util
@@ -42,11 +41,11 @@ import scala.collection.JavaConversions._
* if there exists at least one equal-join condition and
* ShuffleHashJoin or BroadcastHashJoin are enabled.
*/
-class BatchExecHashJoinRule(joinClass: Class[_ <: Join])
+class BatchExecHashJoinRule
extends RelOptRule(
- operand(joinClass,
+ operand(classOf[FlinkLogicalJoin],
operand(classOf[RelNode], any)),
- s"BatchExecHashJoinRule_${joinClass.getSimpleName}")
+ "BatchExecHashJoinRule")
with BatchExecJoinRuleBase {
override def matches(call: RelOptRuleCall): Boolean = {
@@ -61,10 +60,9 @@ class BatchExecHashJoinRule(joinClass: Class[_ <: Join])
val isShuffleHashJoinEnabled = tableConfig.isOperatorEnabled(OperatorType.ShuffleHashJoin)
val isBroadcastHashJoinEnabled = tableConfig.isOperatorEnabled(OperatorType.BroadcastHashJoin)
- val joinType = getFlinkJoinType(join)
val leftSize = binaryRowRelNodeSize(join.getLeft)
val rightSize = binaryRowRelNodeSize(join.getRight)
- val (isBroadcast, _) = canBroadcast(joinType, leftSize, rightSize, tableConfig)
+ val (isBroadcast, _) = canBroadcast(join.getJoinType, leftSize, rightSize, tableConfig)
// TODO use shuffle hash join if isBroadcast is true and isBroadcastHashJoinEnabled is false ?
if (isBroadcast) isBroadcastHashJoinEnabled else isShuffleHashJoinEnabled
@@ -74,10 +72,21 @@ class BatchExecHashJoinRule(joinClass: Class[_ <: Join])
val tableConfig = call.getPlanner.getContext.asInstanceOf[FlinkContext].getTableConfig
val join: Join = call.rel(0)
val joinInfo = join.analyzeCondition
- val joinType = getFlinkJoinType(join)
+ val joinType = join.getJoinType
val left = join.getLeft
- val right = join.getRight
+ val (right, tryDistinctBuildRow) = joinType match {
+ case JoinRelType.SEMI | JoinRelType.ANTI =>
+ // We can apply a distinct to the build side (right) for a SEMI/ANTI join.
+ val distinctKeys = 0 until join.getRight.getRowType.getFieldCount
+ val useBuildDistinct = chooseSemiBuildDistinct(join.getRight, distinctKeys)
+ if (useBuildDistinct) {
+ (addLocalDistinctAgg(join.getRight, distinctKeys, call.builder()), true)
+ } else {
+ (join.getRight, false)
+ }
+ case _ => (join.getRight, false)
+ }
val leftSize = binaryRowRelNodeSize(left)
val rightSize = binaryRowRelNodeSize(right)
@@ -88,8 +97,8 @@ class BatchExecHashJoinRule(joinClass: Class[_ <: Join])
leftIsBroadcast
} else if (leftSize == null || rightSize == null || leftSize == rightSize) {
// use left to build hash table if leftSize or rightSize is unknown or equal size.
- // choose right to build if join is semiJoin.
- !join.isInstanceOf[SemiJoin]
+ // choose right to build if join is SEMI/ANTI.
+ !join.getJoinType.projectsRight
} else {
leftSize < rightSize
}
@@ -107,7 +116,8 @@ class BatchExecHashJoinRule(joinClass: Class[_ <: Join])
join.getCondition,
join.getJoinType,
leftIsBuild,
- isBroadcast)
+ isBroadcast,
+ tryDistinctBuildRow)
call.transformTo(newJoin)
}
@@ -155,7 +165,7 @@ class BatchExecHashJoinRule(joinClass: Class[_ <: Join])
* as broadcast side, false else.
*/
private def canBroadcast(
- joinType: FlinkJoinType,
+ joinType: JoinRelType,
leftSize: JDouble,
rightSize: JDouble,
tableConfig: TableConfig): (Boolean, Boolean) = {
@@ -166,17 +176,17 @@ class BatchExecHashJoinRule(joinClass: Class[_ <: Join])
val threshold = tableConfig.getConf.getLong(
PlannerConfigOptions.SQL_OPTIMIZER_HASH_JOIN_BROADCAST_THRESHOLD)
joinType match {
- case FlinkJoinType.LEFT => (rightSize <= threshold, false)
- case FlinkJoinType.RIGHT => (leftSize <= threshold, true)
- case FlinkJoinType.FULL => (false, false)
- case FlinkJoinType.INNER =>
+ case JoinRelType.LEFT => (rightSize <= threshold, false)
+ case JoinRelType.RIGHT => (leftSize <= threshold, true)
+ case JoinRelType.FULL => (false, false)
+ case JoinRelType.INNER =>
(leftSize <= threshold || rightSize <= threshold, leftSize < rightSize)
// left side cannot be used as build side in SEMI/ANTI join.
- case FlinkJoinType.SEMI | FlinkJoinType.ANTI => (rightSize <= threshold, false)
+ case JoinRelType.SEMI | JoinRelType.ANTI => (rightSize <= threshold, false)
}
}
}
object BatchExecHashJoinRule {
- val INSTANCE = new BatchExecHashJoinRule(classOf[FlinkLogicalJoin])
+ val INSTANCE = new BatchExecHashJoinRule
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecJoinRuleBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecJoinRuleBase.scala
index 87b0567..3c4e0d7 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecJoinRuleBase.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecJoinRuleBase.scala
@@ -19,16 +19,15 @@
package org.apache.flink.table.plan.rules.physical.batch
import org.apache.flink.table.JDouble
+import org.apache.flink.table.api.PlannerConfigOptions
import org.apache.flink.table.plan.nodes.FlinkConventions
-import org.apache.flink.table.plan.nodes.logical.FlinkLogicalJoin
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecLocalHashAggregate
-import org.apache.flink.table.plan.util.{FlinkRelMdUtil, JoinTypeUtil}
-import org.apache.flink.table.runtime.join.FlinkJoinType
+import org.apache.flink.table.plan.util.{FlinkRelMdUtil, FlinkRelOptUtil}
import org.apache.calcite.plan.RelOptRule
import org.apache.calcite.rel.RelNode
-import org.apache.calcite.rel.core.Join
import org.apache.calcite.tools.RelBuilder
+import org.apache.calcite.util.ImmutableBitSet
trait BatchExecJoinRuleBase {
@@ -52,9 +51,21 @@ trait BatchExecJoinRuleBase {
Seq())
}
- def getFlinkJoinType(join: Join): FlinkJoinType = join match {
- case j: FlinkLogicalJoin => JoinTypeUtil.getFlinkJoinType(j)
- case _ => throw new IllegalArgumentException(s"Illegal join node: ${join.getRelTypeName}")
+ def chooseSemiBuildDistinct(
+ buildRel: RelNode,
+ distinctKeys: Seq[Int]): Boolean = {
+ val tableConfig = FlinkRelOptUtil.getTableConfigFromContext(buildRel)
+ val mq = buildRel.getCluster.getMetadataQuery
+ val ratioConf = tableConfig.getConf.getDouble(
+ PlannerConfigOptions.SQL_OPTIMIZER_SEMI_JOIN_BUILD_DISTINCT_NDV_RATIO)
+ val inputRows = mq.getRowCount(buildRel)
+ val ndvOfGroupKey = mq.getDistinctRowCount(
+ buildRel, ImmutableBitSet.of(distinctKeys: _*), null)
+ if (ndvOfGroupKey == null) {
+ false
+ } else {
+ ndvOfGroupKey / inputRows < ratioConf
+ }
}
private[flink] def binaryRowRelNodeSize(relNode: RelNode): JDouble = {
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecNestedLoopJoinRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecNestedLoopJoinRule.scala
index 3ad1469..ebe2459 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecNestedLoopJoinRule.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecNestedLoopJoinRule.scala
@@ -25,17 +25,17 @@ import org.apache.flink.table.plan.nodes.physical.batch.BatchExecNestedLoopJoin
import org.apache.calcite.plan.RelOptRule.{any, operand}
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
import org.apache.calcite.rel.RelNode
-import org.apache.calcite.rel.core.{Join, JoinRelType, SemiJoin}
+import org.apache.calcite.rel.core.{Join, JoinRelType}
/**
* Rule that converts [[FlinkLogicalJoin]] to [[BatchExecNestedLoopJoin]]
* if NestedLoopJoin is enabled.
*/
-class BatchExecNestedLoopJoinRule(joinClass: Class[_ <: Join])
+class BatchExecNestedLoopJoinRule
extends RelOptRule(
- operand(joinClass,
+ operand(classOf[FlinkLogicalJoin],
operand(classOf[RelNode], any)),
- s"BatchExecNestedLoopJoinRule_${joinClass.getSimpleName}")
+ "BatchExecNestedLoopJoinRule")
with BatchExecJoinRuleBase
with BatchExecNestedLoopJoinRuleBase {
@@ -47,16 +47,24 @@ class BatchExecNestedLoopJoinRule(joinClass: Class[_ <: Join])
override def onMatch(call: RelOptRuleCall): Unit = {
val join: Join = call.rel(0)
val left = join.getLeft
- val right = join.getRight
+ val right = join.getJoinType match {
+ case JoinRelType.SEMI | JoinRelType.ANTI =>
+ // We can apply a distinct to the build side (right) for a SEMI/ANTI join.
+ val distinctKeys = 0 until join.getRight.getRowType.getFieldCount
+ val useBuildDistinct = chooseSemiBuildDistinct(join.getRight, distinctKeys)
+ if (useBuildDistinct) {
+ addLocalDistinctAgg(join.getRight, distinctKeys, call.builder())
+ } else {
+ join.getRight
+ }
+ case _ => join.getRight
+ }
val leftIsBuild = isLeftBuild(join, left, right)
val newJoin = createNestedLoopJoin(join, left, right, leftIsBuild, singleRowJoin = false)
call.transformTo(newJoin)
}
private def isLeftBuild(join: Join, left: RelNode, right: RelNode): Boolean = {
- if (join.isInstanceOf[SemiJoin]) {
- return false
- }
join.getJoinType match {
case JoinRelType.LEFT => false
case JoinRelType.RIGHT => true
@@ -69,10 +77,11 @@ class BatchExecNestedLoopJoinRule(joinClass: Class[_ <: Join])
} else {
leftSize <= rightSize
}
+ case JoinRelType.SEMI | JoinRelType.ANTI => false
}
}
}
object BatchExecNestedLoopJoinRule {
- val INSTANCE: RelOptRule = new BatchExecNestedLoopJoinRule(classOf[FlinkLogicalJoin])
+ val INSTANCE: RelOptRule = new BatchExecNestedLoopJoinRule
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecSingleRowJoinRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecSingleRowJoinRule.scala
index d44254a..52afdd7 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecSingleRowJoinRule.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecSingleRowJoinRule.scala
@@ -21,7 +21,6 @@ package org.apache.flink.table.plan.rules.physical.batch
import org.apache.flink.table.plan.nodes.FlinkConventions
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalJoin
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecNestedLoopJoin
-import org.apache.flink.table.runtime.join.FlinkJoinType
import org.apache.calcite.plan.volcano.RelSubset
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
@@ -33,25 +32,23 @@ import org.apache.calcite.rel.core._
* Rule that converts [[FlinkLogicalJoin]] to [[BatchExecNestedLoopJoin]]
* if one of join input sides returns at most a single row.
*/
-class BatchExecSingleRowJoinRule(joinClass: Class[_ <: Join])
+class BatchExecSingleRowJoinRule
extends ConverterRule(
- joinClass,
+ classOf[FlinkLogicalJoin],
FlinkConventions.LOGICAL,
FlinkConventions.BATCH_PHYSICAL,
- s"BatchExecSingleRowJoinRule_${joinClass.getSimpleName}")
+ "BatchExecSingleRowJoinRule")
with BatchExecJoinRuleBase
with BatchExecNestedLoopJoinRuleBase {
override def matches(call: RelOptRuleCall): Boolean = {
val join: Join = call.rel(0)
- val joinType = getFlinkJoinType(join)
- joinType match {
- case FlinkJoinType.INNER | FlinkJoinType.FULL =>
+ join.getJoinType match {
+ case JoinRelType.INNER | JoinRelType.FULL =>
isSingleRow(join.getLeft) || isSingleRow(join.getRight)
- case FlinkJoinType.LEFT if isSingleRow(join.getRight) => true
- case FlinkJoinType.RIGHT if isSingleRow(join.getLeft) => true
- case FlinkJoinType.SEMI if isSingleRow(join.getRight) => true
- case FlinkJoinType.ANTI if isSingleRow(join.getRight) => true
+ case JoinRelType.LEFT => isSingleRow(join.getRight)
+ case JoinRelType.RIGHT => isSingleRow(join.getLeft)
+ case JoinRelType.SEMI | JoinRelType.ANTI => isSingleRow(join.getRight)
case _ => false
}
}
@@ -85,5 +82,5 @@ class BatchExecSingleRowJoinRule(joinClass: Class[_ <: Join])
}
object BatchExecSingleRowJoinRule {
- val INSTANCE: RelOptRule = new BatchExecSingleRowJoinRule(classOf[FlinkLogicalJoin])
+ val INSTANCE: RelOptRule = new BatchExecSingleRowJoinRule
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecSortMergeJoinRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecSortMergeJoinRule.scala
index 992a2ff..3abab91 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecSortMergeJoinRule.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecSortMergeJoinRule.scala
@@ -38,11 +38,11 @@ import scala.collection.JavaConversions._
* Rule that converts [[FlinkLogicalJoin]] to [[BatchExecSortMergeJoin]]
* if there exists at least one equal-join condition and SortMergeJoin is enabled.
*/
-class BatchExecSortMergeJoinRule(joinClass: Class[_ <: Join])
+class BatchExecSortMergeJoinRule
extends RelOptRule(
- operand(joinClass,
+ operand(classOf[FlinkLogicalJoin],
operand(classOf[RelNode], any)),
- s"BatchExecSortMergeJoinRule_${joinClass.getSimpleName}")
+ "BatchExecSortMergeJoinRule")
with BatchExecJoinRuleBase {
override def matches(call: RelOptRuleCall): Boolean = {
@@ -91,7 +91,7 @@ class BatchExecSortMergeJoinRule(joinClass: Class[_ <: Join])
val providedTraitSet = call.getPlanner
.emptyTraitSet()
.replace(FlinkConventions.BATCH_PHYSICAL)
- val newJoin = new BatchExecSortMergeJoin(
+ val newJoin = new BatchExecSortMergeJoin(
join.getCluster,
providedTraitSet,
newLeft,
@@ -139,5 +139,5 @@ class BatchExecSortMergeJoinRule(joinClass: Class[_ <: Join])
}
object BatchExecSortMergeJoinRule {
- val INSTANCE: RelOptRule = new BatchExecSortMergeJoinRule(classOf[FlinkLogicalJoin])
+ val INSTANCE: RelOptRule = new BatchExecSortMergeJoinRule
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecJoinRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecJoinRule.scala
index 5ccbbcf..0f12ae8 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecJoinRule.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecJoinRule.scala
@@ -47,11 +47,15 @@ class StreamExecJoinRule
override def matches(call: RelOptRuleCall): Boolean = {
val join: FlinkLogicalJoin = call.rel(0)
- val tableConfig = call.getPlanner.getContext.asInstanceOf[FlinkContext].getTableConfig
- val joinRowType = join.getRowType
+ if (!join.getJoinType.projectsRight) {
+ // SEMI/ANTI join always converts to StreamExecJoin now
+ return true
+ }
// TODO check LHS or RHS are FlinkLogicalSnapshot
+ val joinRowType = join.getRowType
+ val tableConfig = call.getPlanner.getContext.asInstanceOf[FlinkContext].getTableConfig
val (windowBounds, remainingPreds) = WindowJoinUtil.extractWindowBoundsFromPredicate(
join.getCondition,
join.getLeft.getRowType.getFieldCount,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecWindowJoinRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecWindowJoinRule.scala
index 2ddb7c1..f0fcd86 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecWindowJoinRule.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecWindowJoinRule.scala
@@ -24,7 +24,7 @@ import org.apache.flink.table.plan.`trait`.FlinkRelDistribution
import org.apache.flink.table.plan.nodes.FlinkConventions
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalJoin
import org.apache.flink.table.plan.nodes.physical.stream.StreamExecWindowJoin
-import org.apache.flink.table.plan.util.WindowJoinUtil
+import org.apache.flink.table.plan.util.{FlinkRelOptUtil, WindowJoinUtil}
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet}
import org.apache.calcite.rel.RelNode
@@ -35,7 +35,7 @@ import java.util
import scala.collection.JavaConversions._
/**
- * Rule that converts [[FlinkLogicalJoin]] with window bounds in join condition
+ * Rule that converts non-SEMI/ANTI [[FlinkLogicalJoin]] with window bounds in join condition
* to [[StreamExecWindowJoin]].
*/
class StreamExecWindowJoinRule
@@ -46,21 +46,23 @@ class StreamExecWindowJoinRule
"StreamExecWindowJoinRule") {
override def matches(call: RelOptRuleCall): Boolean = {
- val join: FlinkLogicalJoin = call.rel(0).asInstanceOf[FlinkLogicalJoin]
+ val join: FlinkLogicalJoin = call.rel(0)
val joinRowType = join.getRowType
val joinInfo = join.analyzeCondition()
// joins require an equi-condition or a conjunctive predicate with at least one equi-condition
- if (joinInfo.pairs().isEmpty) {
+ // TODO support SEMI/ANTI join
+ if (!join.getJoinType.projectsRight || joinInfo.pairs().isEmpty) {
return false
}
+ val tableConfig = FlinkRelOptUtil.getTableConfigFromContext(join)
val (windowBounds, _) = WindowJoinUtil.extractWindowBoundsFromPredicate(
join.getCondition,
join.getLeft.getRowType.getFieldCount,
joinRowType,
join.getCluster.getRexBuilder,
- TableConfig.DEFAULT)
+ tableConfig)
if (windowBounds.isDefined) {
if (windowBounds.get.isEventTime) {
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelMdUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelMdUtil.scala
index e6275b8..18edfba 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelMdUtil.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelMdUtil.scala
@@ -30,7 +30,7 @@ import org.apache.flink.table.typeutils.BinaryRowSerializer
import com.google.common.collect.ImmutableList
import org.apache.calcite.avatica.util.TimeUnitRange._
import org.apache.calcite.plan.RelOptUtil
-import org.apache.calcite.rel.core.{Aggregate, AggregateCall, Calc, JoinInfo, SemiJoin}
+import org.apache.calcite.rel.core.{Aggregate, AggregateCall, Calc, Join, JoinInfo, JoinRelType}
import org.apache.calcite.rel.metadata.{RelMdUtil, RelMetadataQuery}
import org.apache.calcite.rel.{RelNode, SingleRel}
import org.apache.calcite.rex.{RexBuilder, RexCall, RexInputRef, RexLiteral, RexNode, RexUtil, RexVisitorImpl}
@@ -49,6 +49,20 @@ import scala.collection.mutable
*/
object FlinkRelMdUtil {
+ /** Returns an estimate of the number of rows returned by a SEMI/ANTI [[Join]]. */
+ def getSemiAntiJoinRowCount(mq: RelMetadataQuery, left: RelNode, right: RelNode,
+ joinType: JoinRelType, condition: RexNode, isAnti: Boolean): JDouble = {
+ val leftCount = mq.getRowCount(left)
+ if (leftCount == null) {
+ return null
+ }
+ var selectivity = RexUtil.getSelectivity(condition)
+ if (isAnti) {
+ selectivity = 1d - selectivity
+ }
+ leftCount * selectivity
+ }
+
/**
* Creates a RexNode that stores a selectivity value corresponding to the
* selectivity of a semi-join/anti-join. This can be added to a filter to simulate the
@@ -56,26 +70,24 @@ object FlinkRelMdUtil {
* plan since it has no physical implementation.
*
* @param mq instance of metadata query
- * @param rel the semiJoin or antiJoin of interest
+ * @param rel the SEMI/ANTI join of interest
* @return constructed rexNode
*/
- def makeSemiJoinSelectivityRexNode(
- mq: RelMetadataQuery,
- rel: SemiJoin): RexNode = {
+ def makeSemiAntiJoinSelectivityRexNode(mq: RelMetadataQuery, rel: Join): RexNode = {
+ require(rel.getJoinType == JoinRelType.SEMI || rel.getJoinType == JoinRelType.ANTI)
val joinInfo = rel.analyzeCondition()
val rexBuilder = rel.getCluster.getRexBuilder
- makeSemiJoinSelectivityRexNode(
- mq, joinInfo, rel.getLeft, rel.getRight, isAnti = false, rexBuilder)
+ makeSemiAntiJoinSelectivityRexNode(
+ mq, joinInfo, rel.getLeft, rel.getRight, rel.getJoinType == JoinRelType.ANTI, rexBuilder)
}
- private def makeSemiJoinSelectivityRexNode(
+ private def makeSemiAntiJoinSelectivityRexNode(
mq: RelMetadataQuery,
joinInfo: JoinInfo,
left: RelNode,
right: RelNode,
isAnti: Boolean,
rexBuilder: RexBuilder): RexNode = {
-
val equiSelectivity: JDouble = if (!joinInfo.leftKeys.isEmpty) {
RelMdUtil.computeSemiJoinSelectivity(mq, left, right, joinInfo.leftKeys, joinInfo.rightKeys)
} else {
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelOptUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelOptUtil.scala
index 8472036..f6a8935 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelOptUtil.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelOptUtil.scala
@@ -17,22 +17,30 @@
*/
package org.apache.flink.table.plan.util
-import org.apache.flink.table.api.TableConfig
+import org.apache.flink.table.api.{PlannerConfigOptions, TableConfig}
import org.apache.flink.table.calcite.{FlinkContext, FlinkPlannerImpl}
import org.apache.flink.table.{JBoolean, JByte, JDouble, JFloat, JLong, JShort}
+import com.google.common.collect.{ImmutableList, Lists}
import org.apache.calcite.config.NullCollation
-import org.apache.calcite.plan.RelOptUtil
+import org.apache.calcite.plan.RelOptUtil.InputFinder
+import org.apache.calcite.plan.{RelOptUtil, Strong}
import org.apache.calcite.rel.RelFieldCollation.{Direction, NullDirection}
+import org.apache.calcite.rel.`type`.RelDataTypeField
+import org.apache.calcite.rel.core.{Join, JoinRelType}
import org.apache.calcite.rel.{RelFieldCollation, RelNode}
import org.apache.calcite.rex.{RexBuilder, RexCall, RexInputRef, RexLiteral, RexNode, RexUtil, RexVisitorImpl}
import org.apache.calcite.sql.SqlExplainLevel
-import org.apache.calcite.sql.SqlKind.{AND, IS_FALSE, IS_TRUE, NOT, OR}
+import org.apache.calcite.sql.SqlKind._
import org.apache.calcite.sql.`type`.SqlTypeName._
+import org.apache.calcite.tools.RelBuilder
+import org.apache.calcite.util.mapping.Mappings
+import org.apache.calcite.util.{ImmutableBitSet, Pair, Util}
import java.io.{PrintWriter, StringWriter}
import java.math.BigDecimal
import java.sql.{Date, Time, Timestamp}
+import java.util
import java.util.Calendar
import scala.collection.JavaConversions._
@@ -132,10 +140,16 @@ object FlinkRelOptUtil {
new RelFieldCollation(fieldIndex, direction, nullDirection)
}
- def getTableConfigFromContext(rel: RelNode): TableConfig = {
+ def getTableConfigFromContext(rel: RelNode): TableConfig = {
rel.getCluster.getPlanner.getContext.asInstanceOf[FlinkContext].getTableConfig
}
+ /** Gets the max CNF node limit from the table config in the context of the given rel. */
+ def getMaxCnfNodeCount(rel: RelNode): Int = {
+ val tableConfig = getTableConfigFromContext(rel)
+ tableConfig.getConf.getInteger(PlannerConfigOptions.SQL_OPTIMIZER_CNF_NODES_LIMIT)
+ }
+
/**
* Gets values of RexLiteral
*
@@ -172,6 +186,384 @@ object FlinkRelOptUtil {
}
/**
+ * Simplifies outer joins if filter above would reject nulls.
+ *
+ * NOTES: This method should be deleted when upgrading to a new calcite version
+ * which contains CALCITE-2969.
+ *
+ * @param joinRel Join
+ * @param aboveFilters Filters from above
+ * @param joinType Join type, can not be inner join
+ */
+ def simplifyJoin(
+ joinRel: RelNode,
+ aboveFilters: ImmutableList[RexNode],
+ joinType: JoinRelType): JoinRelType = {
+ // No need to simplify if the join type does not project the right input (SEMI/ANTI).
+ if (!joinType.projectsRight()) {
+ return joinType
+ }
+ val nTotalFields = joinRel.getRowType.getFieldCount
+ val nSysFields = 0
+ val nFieldsLeft = joinRel.getInputs.get(0).getRowType.getFieldCount
+ val nFieldsRight = joinRel.getInputs.get(1).getRowType.getFieldCount
+ assert(nTotalFields == nSysFields + nFieldsLeft + nFieldsRight)
+
+ // set the reference bitmaps for the left and right children
+ val leftBitmap = ImmutableBitSet.range(nSysFields, nSysFields + nFieldsLeft)
+ val rightBitmap = ImmutableBitSet.range(nSysFields + nFieldsLeft, nTotalFields)
+
+ var result = joinType
+ for (filter <- aboveFilters) {
+ if (joinType.generatesNullsOnLeft && Strong.isNotTrue(filter, leftBitmap)) {
+ result = result.cancelNullsOnLeft
+ }
+ if (joinType.generatesNullsOnRight && Strong.isNotTrue(filter, rightBitmap)) {
+ result = result.cancelNullsOnRight
+ }
+ if (joinType eq JoinRelType.INNER) {
+ return result
+ }
+ }
+ result
+ }
+
+ /**
+ * Classifies filters according to where they should be processed. They
+ * either stay where they are, are pushed to the join (if they originated
+ * from above the join), or are pushed to one of the children. Filters that
+ * are pushed are added to the lists passed in as input parameters.
+ *
+ * NOTES: This method should be deleted when upgrading to a new calcite version
+ * which contains CALCITE-2969.
+ *
+ * @param joinRel join node
+ * @param filters filters to be classified
+ * @param joinType join type
+ * @param pushInto whether filters can be pushed into the ON clause
+ * @param pushLeft true if filters can be pushed to the left
+ * @param pushRight true if filters can be pushed to the right
+ * @param joinFilters list of filters to push to the join
+ * @param leftFilters list of filters to push to the left child
+ * @param rightFilters list of filters to push to the right child
+ * @return whether at least one filter was pushed
+ */
+ def classifyFilters(
+ joinRel: RelNode,
+ filters: util.List[RexNode],
+ joinType: JoinRelType,
+ pushInto: Boolean,
+ pushLeft: Boolean,
+ pushRight: Boolean,
+ joinFilters: util.List[RexNode],
+ leftFilters: util.List[RexNode],
+ rightFilters: util.List[RexNode]): Boolean = {
+ val rexBuilder = joinRel.getCluster.getRexBuilder
+ val joinFields = joinRel.getRowType.getFieldList
+ val nTotalFields = joinFields.size
+ val nSysFields = 0 // joinRel.getSystemFieldList().size();
+ val leftFields = joinRel.getInputs.get(0).getRowType.getFieldList
+ val nFieldsLeft = leftFields.size
+ val rightFields = joinRel.getInputs.get(1).getRowType.getFieldList
+ val nFieldsRight = rightFields.size
+
+ assert(nTotalFields == (if (joinType.projectsRight()) {
+ nSysFields + nFieldsLeft + nFieldsRight
+ } else {
+ // SEMI/ANTI
+ nSysFields + nFieldsLeft
+ }))
+
+ // set the reference bitmaps for the left and right children
+ val leftBitmap = ImmutableBitSet.range(nSysFields, nSysFields + nFieldsLeft)
+ val rightBitmap = ImmutableBitSet.range(nSysFields + nFieldsLeft, nTotalFields)
+
+ val filtersToRemove = new util.ArrayList[RexNode]
+
+ filters.foreach { filter =>
+ val inputFinder = InputFinder.analyze(filter)
+ val inputBits = inputFinder.inputBitSet.build
+ // REVIEW - are there any expressions that need special handling
+ // and therefore cannot be pushed?
+ // filters can be pushed to the left child if the left child
+ // does not generate NULLs and the only columns referenced in
+ // the filter originate from the left child
+ if (pushLeft && leftBitmap.contains(inputBits)) {
+ // ignore filters that always evaluate to true
+ if (!filter.isAlwaysTrue) {
+ // adjust the field references in the filter to reflect
+ // that fields in the left now shift over by the number
+ // of system fields
+ val shiftedFilter = shiftFilter(
+ nSysFields,
+ nSysFields + nFieldsLeft,
+ -nSysFields,
+ rexBuilder,
+ joinFields,
+ nTotalFields,
+ leftFields,
+ filter)
+ leftFilters.add(shiftedFilter)
+ }
+ filtersToRemove.add(filter)
+
+ // filters can be pushed to the right child if the right child
+ // does not generate NULLs and the only columns referenced in
+ // the filter originate from the right child
+ } else if (pushRight && rightBitmap.contains(inputBits)) {
+ if (!filter.isAlwaysTrue) {
+ // adjust the field references in the filter: fields in the right now shift over to the left;
+ // since we never push filters to a NULL generating
+ // child, the types of the source should match the dest
+ // so we don't need to explicitly pass the destination
+ // fields to RexInputConverter
+ val shiftedFilter = shiftFilter(
+ nSysFields + nFieldsLeft,
+ nTotalFields,
+ -(nSysFields + nFieldsLeft),
+ rexBuilder,
+ joinFields,
+ nTotalFields,
+ rightFields,
+ filter)
+ rightFilters.add(shiftedFilter)
+ }
+ filtersToRemove.add(filter)
+ } else {
+ // If the filter can't be pushed to either child and the join
+ // is an inner join, push them to the join if they originated
+ // from above the join
+ if ((joinType eq JoinRelType.INNER) && pushInto) {
+ if (!joinFilters.contains(filter)) {
+ joinFilters.add(filter)
+ }
+ filtersToRemove.add(filter)
+ }
+ }
+ }
+ // Remove filters after the loop, to prevent concurrent modification.
+ if (!filtersToRemove.isEmpty) {
+ filters.removeAll(filtersToRemove)
+ }
+ // Did anything change?
+ !filtersToRemove.isEmpty
+ }
+
+ private def shiftFilter(
+ start: Int,
+ end: Int,
+ offset: Int,
+ rexBuilder: RexBuilder,
+ joinFields: util.List[RelDataTypeField],
+ nTotalFields: Int,
+ rightFields: util.List[RelDataTypeField],
+ filter: RexNode): RexNode = {
+ val adjustments = new Array[Int](nTotalFields)
+ (start until end).foreach {
+ i => adjustments(i) = offset
+ }
+ filter.accept(
+ new RelOptUtil.RexInputConverter(
+ rexBuilder,
+ joinFields,
+ rightFields,
+ adjustments)
+ )
+ }
+
+ /**
+ * Pushes down expressions in "equal" join condition.
+ *
+ * NOTES: This method should be deleted when upgrading to a new calcite version
+ * which contains CALCITE-2969.
+ *
+ * <p>For example, given
+ * "emp JOIN dept ON emp.deptno + 1 = dept.deptno", adds a project above
+ * "emp" that computes the expression
+ * "emp.deptno + 1". The resulting join condition is a simple combination
+ * of AND, equals, and input fields, plus the remaining non-equal conditions.
+ *
+ * @param originalJoin Join whose condition is to be pushed down
+ * @param relBuilder Factory to create project operator
+ */
+ def pushDownJoinConditions(originalJoin: Join, relBuilder: RelBuilder): RelNode = {
+ var joinCond: RexNode = originalJoin.getCondition
+ val joinType: JoinRelType = originalJoin.getJoinType
+
+ val extraLeftExprs: util.List[RexNode] = new util.ArrayList[RexNode]
+ val extraRightExprs: util.List[RexNode] = new util.ArrayList[RexNode]
+ val leftCount: Int = originalJoin.getLeft.getRowType.getFieldCount
+ val rightCount: Int = originalJoin.getRight.getRowType.getFieldCount
+
+ // You cannot push a 'get' because field names might change.
+ //
+ // Pushing sub-queries is OK in principle (if they don't reference both
+ // sides of the join via correlating variables) but we'd rather not do it
+ // yet.
+ if (!containsGet(joinCond) && RexUtil.SubQueryFinder.find(joinCond) == null) {
+ joinCond = pushDownEqualJoinConditions(
+ joinCond, leftCount, rightCount, extraLeftExprs, extraRightExprs)
+ }
+ relBuilder.push(originalJoin.getLeft)
+ if (!extraLeftExprs.isEmpty) {
+ val fields: util.List[RelDataTypeField] = relBuilder.peek.getRowType.getFieldList
+ val pairs: util.List[Pair[RexNode, String]] = new util.AbstractList[Pair[RexNode, String]]() {
+ override def size: Int = leftCount + extraLeftExprs.size
+
+ override def get(index: Int): Pair[RexNode, String] = if (index < leftCount) {
+ val field: RelDataTypeField = fields.get(index)
+ Pair.of(new RexInputRef(index, field.getType), field.getName)
+ }
+ else Pair.of(extraLeftExprs.get(index - leftCount), null)
+ }
+ relBuilder.project(Pair.left(pairs), Pair.right(pairs))
+ }
+
+ relBuilder.push(originalJoin.getRight)
+ if (!extraRightExprs.isEmpty) {
+ val fields: util.List[RelDataTypeField] = relBuilder.peek.getRowType.getFieldList
+ val newLeftCount: Int = leftCount + extraLeftExprs.size
+ val pairs: util.List[Pair[RexNode, String]] = new util.AbstractList[Pair[RexNode, String]]() {
+ override def size: Int = rightCount + extraRightExprs.size
+
+ override def get(index: Int): Pair[RexNode, String] = if (index < rightCount) {
+ val field: RelDataTypeField = fields.get(index)
+ Pair.of(new RexInputRef(index, field.getType), field.getName)
+ }
+ else Pair.of(RexUtil.shift(extraRightExprs.get(index - rightCount), -newLeftCount), null)
+ }
+ relBuilder.project(Pair.left(pairs), Pair.right(pairs))
+ }
+
+ val right: RelNode = relBuilder.build
+ val left: RelNode = relBuilder.build
+ relBuilder.push(originalJoin.copy(originalJoin.getTraitSet, joinCond, left, right, joinType,
+ originalJoin.isSemiJoinDone))
+
+ // handle SEMI/ANTI join here
+ var mapping: Mappings.TargetMapping = null
+ if (!originalJoin.getJoinType.projectsRight()) {
+ if (!extraLeftExprs.isEmpty) {
+ mapping = Mappings.createShiftMapping(leftCount + extraLeftExprs.size, 0, 0, leftCount)
+ }
+ } else {
+ if (!extraLeftExprs.isEmpty || !extraRightExprs.isEmpty) {
+ mapping = Mappings.createShiftMapping(
+ leftCount + extraLeftExprs.size + rightCount + extraRightExprs.size,
+ 0, 0, leftCount, leftCount, leftCount + extraLeftExprs.size, rightCount)
+ }
+ }
+
+ if (mapping != null) {
+ relBuilder.project(relBuilder.fields(mapping.inverse))
+ }
+ relBuilder.build
+ }
+
+ private def containsGet(node: RexNode) = try {
+ node.accept(new RexVisitorImpl[Void](true) {
+ override def visitCall(call: RexCall): Void = {
+ if (call.getOperator eq RexBuilder.GET_OPERATOR) {
+ throw Util.FoundOne.NULL
+ }
+ super.visitCall(call)
+ }
+ })
+ false
+ } catch {
+ case _: Util.FoundOne =>
+ true
+ }
+
+ /**
+ * Pushes down parts of a join condition.
+ *
+ * <p>For example, given
+ * "emp JOIN dept ON emp.deptno + 1 = dept.deptno", adds a project above
+ * "emp" that computes the expression
+ * "emp.deptno + 1". The resulting join condition is a simple combination
+ * of AND, equals, and input fields.
+ */
+ private def pushDownEqualJoinConditions(
+ node: RexNode,
+ leftCount: Int,
+ rightCount: Int,
+ extraLeftExprs: util.List[RexNode],
+ extraRightExprs: util.List[RexNode]): RexNode =
+ node.getKind match {
+ case AND | EQUALS =>
+ val call = node.asInstanceOf[RexCall]
+ val list = new util.ArrayList[RexNode]
+ val operands = Lists.newArrayList(call.getOperands)
+ // do not use `operands.zipWithIndex.foreach`
+ operands.indices.foreach { i =>
+ val operand = operands.get(i)
+ val left2 = leftCount + extraLeftExprs.size
+ val right2 = rightCount + extraRightExprs.size
+ val e = pushDownEqualJoinConditions(
+ operand, leftCount, rightCount, extraLeftExprs, extraRightExprs)
+ val remainingOperands = Util.skip(operands, i + 1)
+ val left3 = leftCount + extraLeftExprs.size
+ fix(remainingOperands, left2, left3)
+ fix(list, left2, left3)
+ list.add(e)
+ }
+
+ if (!(list == call.getOperands)) {
+ call.clone(call.getType, list)
+ } else {
+ call
+ }
+ case OR | INPUT_REF | LITERAL | NOT => node
+ case _ =>
+ val bits = RelOptUtil.InputFinder.bits(node)
+ val mid = leftCount + extraLeftExprs.size
+ Side.of(bits, mid) match {
+ case Side.LEFT =>
+ fix(extraRightExprs, mid, mid + 1)
+ extraLeftExprs.add(node)
+ new RexInputRef(mid, node.getType)
+ case Side.RIGHT =>
+ val index2 = mid + rightCount + extraRightExprs.size
+ extraRightExprs.add(node)
+ new RexInputRef(index2, node.getType)
+ case _ => node
+ }
+ }
+
+ private def fix(operands: util.List[RexNode], before: Int, after: Int): Unit = {
+ if (before == after) {
+ return
+ }
+ operands.indices.foreach { i =>
+ val node = operands.get(i)
+ operands.set(i, RexUtil.shift(node, before, after - before))
+ }
+ }
+
+ /**
+ * Categorizes whether a bit set contains bits left and right of a line.
+ */
+ private object Side extends Enumeration {
+ type Side = Value
+ val LEFT, RIGHT, BOTH, EMPTY = Value
+
+ private[plan] def of(bitSet: ImmutableBitSet, middle: Int): Side = {
+ val firstBit = bitSet.nextSetBit(0)
+ if (firstBit < 0) {
+ return EMPTY
+ }
+ if (firstBit >= middle) {
+ return RIGHT
+ }
+ if (bitSet.nextSetBit(middle) < 0) {
+ return LEFT
+ }
+ BOTH
+ }
+ }
+
+ /**
* Partitions the [[RexNode]] in two [[RexNode]] according to a predicate.
* The result is a pair of RexNode: the first RexNode consists of RexNode that satisfy the
* predicate and the second RexNode consists of RexNode that don't.
@@ -316,4 +708,5 @@ object FlinkRelOptUtil {
})
}
}
+
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/UpdatingPlanChecker.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/UpdatingPlanChecker.scala
index e5e32e3..0be43d5 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/UpdatingPlanChecker.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/UpdatingPlanChecker.scala
@@ -160,9 +160,9 @@ object UpdatingPlanChecker {
.zip(joinNames.subList(lInNames.size, joinNames.length))
.toMap
- val lJoinKeys: Seq[String] = j.joinInfo.leftKeys
+ val lJoinKeys: Seq[String] = j.getJoinInfo.leftKeys
.map(lInNames.get(_))
- val rJoinKeys: Seq[String] = j.joinInfo.rightKeys
+ val rJoinKeys: Seq[String] = j.getJoinInfo.rightKeys
.map(rInNames.get(_))
.map(rInNamesToJoinNamesMap(_))
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/DeadlockBreakupTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/DeadlockBreakupTest.xml
index 8832f60..cabfaec 100644
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/DeadlockBreakupTest.xml
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/DeadlockBreakupTest.xml
@@ -84,8 +84,8 @@ LogicalProject(a=[$0], b=[$1], c=[$2], a0=[$3], b0=[$4], c0=[$5], a1=[$6], b1=[$
</Resource>
<Resource name="planAfter">
<![CDATA[
-HashJoin(joinType=[InnerJoin], where=[=(c, c1)], select=[a, b, c, a0, b0, c0, a1, b1, c1, a00, b00, c00], build=[left])
-:- Exchange(distribution=[hash[c]])
+HashJoin(joinType=[InnerJoin], where=[=(c, c1)], select=[a, b, c, a0, b0, c0, a1, b1, c1, a00, b00, c00], build=[right])
+:- Exchange(distribution=[hash[c]], exchange_mode=[BATCH])
: +- HashJoin(joinType=[InnerJoin], where=[=(a, a0)], select=[a, b, c, a0, b0, c0], build=[left])
: :- Exchange(distribution=[hash[a]])
: : +- Calc(select=[a, b, c], where=[>(b, 10)])
@@ -97,7 +97,7 @@ HashJoin(joinType=[InnerJoin], where=[=(c, c1)], select=[a, b, c, a0, b0, c0, a1
: : +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
: +- Exchange(distribution=[hash[a]], exchange_mode=[BATCH], reuse_id=[2])
: +- Reused(reference_id=[1])
-+- Exchange(distribution=[hash[c]], exchange_mode=[BATCH])
++- Exchange(distribution=[hash[c]])
+- HashJoin(joinType=[InnerJoin], where=[=(a, a0)], select=[a, b, c, a0, b0, c0], build=[left])
:- Exchange(distribution=[hash[a]])
: +- Calc(select=[a, b, c], where=[<(b, 5)])
@@ -140,20 +140,16 @@ Calc(select=[a, b])
: +- Calc(select=[a], where=[=(b, 5:BIGINT)])
: +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c], reuse_id=[1])
+- Exchange(distribution=[hash[a]])
- +- Calc(select=[a, b])
- +- HashJoin(joinType=[InnerJoin], where=[=(a, a0)], select=[a, b, a0], build=[left])
- :- Exchange(distribution=[hash[a]])
- : +- Calc(select=[a, b])
- : +- Limit(offset=[0], fetch=[10], global=[true])
- : +- Exchange(distribution=[single])
- : +- Limit(offset=[0], fetch=[10], global=[false])
- : +- Reused(reference_id=[1])
- +- Exchange(distribution=[hash[a]])
- +- HashAggregate(isMerge=[true], groupBy=[a], select=[a])
- +- Exchange(distribution=[hash[a]])
- +- LocalHashAggregate(groupBy=[a], select=[a])
- +- Calc(select=[a], where=[>(b, 5)])
- +- Reused(reference_id=[1])
+ +- HashJoin(joinType=[LeftSemiJoin], where=[=(a, a0)], select=[a, b], build=[left])
+ :- Exchange(distribution=[hash[a]])
+ : +- Calc(select=[a, b])
+ : +- Limit(offset=[0], fetch=[10], global=[true])
+ : +- Exchange(distribution=[single])
+ : +- Limit(offset=[0], fetch=[10], global=[false])
+ : +- Reused(reference_id=[1])
+ +- Exchange(distribution=[hash[a]], exchange_mode=[BATCH])
+ +- Calc(select=[a], where=[>(b, 5)])
+ +- Reused(reference_id=[1])
]]>
</Resource>
</TestCase>
@@ -253,8 +249,8 @@ LogicalProject(a=[$0], b=[$1], d=[$2], e=[$3], a0=[$4], b0=[$5], d0=[$6], e0=[$7
</Resource>
<Resource name="planAfter">
<![CDATA[
-HashJoin(joinType=[InnerJoin], where=[=(b, e0)], select=[a, b, d, e, a0, b0, d0, e0], build=[left])
-:- Exchange(distribution=[hash[b]])
+HashJoin(joinType=[InnerJoin], where=[=(b, e0)], select=[a, b, d, e, a0, b0, d0, e0], build=[right])
+:- Exchange(distribution=[hash[b]], exchange_mode=[BATCH])
: +- HashJoin(joinType=[InnerJoin], where=[=(a, d)], select=[a, b, d, e], build=[left])
: :- Exchange(distribution=[hash[a]])
: : +- Calc(select=[a, b], where=[<(a, 10)])
@@ -262,7 +258,7 @@ HashJoin(joinType=[InnerJoin], where=[=(b, e0)], select=[a, b, d, e, a0, b0, d0,
: +- Exchange(distribution=[hash[d]], reuse_id=[2])
: +- Calc(select=[d, e])
: +- TableSourceScan(table=[[y, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
-+- Exchange(distribution=[hash[e]], exchange_mode=[BATCH])
++- Exchange(distribution=[hash[e]])
+- HashJoin(joinType=[InnerJoin], where=[=(a, d)], select=[a, b, d, e], build=[left])
:- Exchange(distribution=[hash[a]])
: +- Calc(select=[a, b], where=[>(a, 5)])
@@ -296,8 +292,8 @@ LogicalProject(a=[$0], b=[$1], c=[$2], a0=[$3], b0=[$4], c0=[$5])
</Resource>
<Resource name="planAfter">
<![CDATA[
-HashJoin(joinType=[InnerJoin], where=[=(c, c0)], select=[a, b, c, a0, b0, c0], build=[left])
-:- Exchange(distribution=[hash[c]])
+HashJoin(joinType=[InnerJoin], where=[=(c, c0)], select=[a, b, c, a0, b0, c0], build=[right])
+:- Exchange(distribution=[hash[c]], exchange_mode=[BATCH])
: +- Calc(select=[w0$o0 AS a, b, c])
: +- OverAggregate(partitionBy=[b], orderBy=[b ASC], window#0=[MAX($2) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], select=[b, c, $2, w0$o0])
: +- Sort(orderBy=[b ASC], reuse_id=[1])
@@ -307,7 +303,7 @@ HashJoin(joinType=[InnerJoin], where=[=(c, c0)], select=[a, b, c, a0, b0, c0], b
: +- Sort(orderBy=[b ASC])
: +- Exchange(distribution=[hash[b]])
: +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
-+- Exchange(distribution=[hash[c]], exchange_mode=[BATCH])
++- Exchange(distribution=[hash[c]])
+- Calc(select=[w0$o0 AS a, b, c])
+- OverAggregate(partitionBy=[b], orderBy=[b ASC], window#0=[MIN($2) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], select=[b, c, $2, w0$o0])
+- Reused(reference_id=[1])
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SubplanReuseTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SubplanReuseTest.xml
index c737ddb..27400aa 100644
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SubplanReuseTest.xml
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SubplanReuseTest.xml
@@ -112,7 +112,7 @@ LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], e=[$4], f=[$5], a0=[$6], b0=[$7],
</Resource>
<Resource name="planAfter">
<![CDATA[
-HashJoin(joinType=[InnerJoin], where=[=(b, e0)], select=[a, b, c, d, e, f, a0, b0, c0, d0, e0, f0], build=[left])
+HashJoin(joinType=[InnerJoin], where=[=(b, e0)], select=[a, b, c, d, e, f, a0, b0, c0, d0, e0, f0], build=[right])
:- Exchange(distribution=[hash[b]])
: +- HashJoin(joinType=[InnerJoin], where=[=(a, d)], select=[a, b, c, d, e, f], build=[left])
: :- Exchange(distribution=[hash[a]])
@@ -156,8 +156,8 @@ LogicalProject(a=[$0], b=[$1], d=[$2], e=[$3], a0=[$4], b0=[$5], d0=[$6], e0=[$7
</Resource>
<Resource name="planAfter">
<![CDATA[
-HashJoin(joinType=[InnerJoin], where=[=(b, e0)], select=[a, b, d, e, a0, b0, d0, e0], build=[left])
-:- Exchange(distribution=[hash[b]])
+HashJoin(joinType=[InnerJoin], where=[=(b, e0)], select=[a, b, d, e, a0, b0, d0, e0], build=[right])
+:- Exchange(distribution=[hash[b]], exchange_mode=[BATCH])
: +- HashJoin(joinType=[InnerJoin], where=[=(a, d)], select=[a, b, d, e], build=[left])
: :- Exchange(distribution=[hash[a]])
: : +- Calc(select=[a, b], where=[<(a, 10)])
@@ -165,7 +165,7 @@ HashJoin(joinType=[InnerJoin], where=[=(b, e0)], select=[a, b, d, e, a0, b0, d0,
: +- Exchange(distribution=[hash[d]], reuse_id=[2])
: +- Calc(select=[d, e])
: +- TableSourceScan(table=[[y, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
-+- Exchange(distribution=[hash[e]], exchange_mode=[BATCH])
++- Exchange(distribution=[hash[e]])
+- HashJoin(joinType=[InnerJoin], where=[=(a, d)], select=[a, b, d, e], build=[left])
:- Exchange(distribution=[hash[a]])
: +- Calc(select=[a, b], where=[>(a, 5)])
@@ -574,15 +574,15 @@ LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], e=[$4], f=[$5], a0=[$6], b0=[$7],
</Resource>
<Resource name="planAfter">
<![CDATA[
-HashJoin(joinType=[InnerJoin], where=[=(a, d0)], select=[a, b, c, d, e, f, a0, b0, c0, d0, e0, f0], build=[left])
-:- Exchange(distribution=[hash[a]])
+HashJoin(joinType=[InnerJoin], where=[=(a, d0)], select=[a, b, c, d, e, f, a0, b0, c0, d0, e0, f0], build=[right])
+:- Exchange(distribution=[hash[a]], exchange_mode=[BATCH])
: +- HashJoin(joinType=[InnerJoin], where=[=(a, d)], select=[a, b, c, d, e, f], build=[left], reuse_id=[1])
: :- Exchange(distribution=[hash[a]])
: : +- Calc(select=[a, b, c], where=[LIKE(c, _UTF-16LE'He%')])
: : +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
: +- Exchange(distribution=[hash[d]])
: +- TableSourceScan(table=[[y, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
-+- Exchange(distribution=[hash[d]], exchange_mode=[BATCH])
++- Exchange(distribution=[hash[d]])
+- Reused(reference_id=[1])
]]>
</Resource>
@@ -760,14 +760,14 @@ LogicalProject(a=[$0], b=[$1], EXPR$2=[$2], a0=[$3], b0=[$4], EXPR$20=[$5])
</Resource>
<Resource name="planAfter">
<![CDATA[
-HashJoin(joinType=[InnerJoin], where=[=(a, a0)], select=[a, b, $2, a0, b0, $20], build=[left])
-:- Exchange(distribution=[hash[a]])
+HashJoin(joinType=[InnerJoin], where=[=(a, a0)], select=[a, b, $2, a0, b0, $20], build=[right])
+:- Exchange(distribution=[hash[a]], exchange_mode=[BATCH])
: +- Calc(select=[a, b, w0$o0 AS $2], where=[<(b, 100)])
: +- OverAggregate(orderBy=[c DESC], window#0=[RANK(*) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o0], reuse_id=[1])
: +- Sort(orderBy=[c DESC])
: +- Exchange(distribution=[single])
: +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
-+- Exchange(distribution=[hash[a]], exchange_mode=[BATCH])
++- Exchange(distribution=[hash[a]])
+- Calc(select=[a, b, w0$o0 AS $2], where=[>(b, 10)])
+- Reused(reference_id=[1])
]]>
@@ -793,14 +793,14 @@ LogicalProject(a=[$0], b=[$1], EXPR$2=[$2], a0=[$3], b0=[$4], EXPR$20=[$5])
</Resource>
<Resource name="planAfter">
<![CDATA[
-HashJoin(joinType=[InnerJoin], where=[=(a, a0)], select=[a, b, $2, a0, b0, $20], build=[left])
-:- Exchange(distribution=[hash[a]])
+HashJoin(joinType=[InnerJoin], where=[=(a, a0)], select=[a, b, $2, a0, b0, $20], build=[right])
+:- Exchange(distribution=[hash[a]], exchange_mode=[BATCH])
: +- Calc(select=[a, b, w0$o0 AS $2], where=[<(b, 100)])
: +- OverAggregate(partitionBy=[c], orderBy=[c DESC], window#0=[MyFirst(c) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o0], reuse_id=[1])
: +- Sort(orderBy=[c DESC])
: +- Exchange(distribution=[hash[c]])
: +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
-+- Exchange(distribution=[hash[a]], exchange_mode=[BATCH])
++- Exchange(distribution=[hash[a]])
+- Calc(select=[a, b, w0$o0 AS $2], where=[>(b, 10)])
+- Reused(reference_id=[1])
]]>
@@ -828,14 +828,14 @@ LogicalProject(c=[$0], a=[$1], b=[$2], c0=[$3], a0=[$4], b0=[$5])
</Resource>
<Resource name="planAfter">
<![CDATA[
-HashJoin(joinType=[InnerJoin], where=[=(a, a0)], select=[c, a, b, c0, a0, b0], build=[left])
-:- Exchange(distribution=[hash[a]])
+HashJoin(joinType=[InnerJoin], where=[=(a, a0)], select=[c, a, b, c0, a0, b0], build=[right])
+:- Exchange(distribution=[hash[a]], exchange_mode=[BATCH])
: +- Calc(select=[c, a, b], where=[>(a, 1)])
: +- HashAggregate(isMerge=[true], groupBy=[c], select=[c, Final_SUM(sum$0) AS a, Final_SUM(sum$1) AS b], reuse_id=[1])
: +- Exchange(distribution=[hash[c]])
: +- LocalHashAggregate(groupBy=[c], select=[c, Partial_SUM(a) AS sum$0, Partial_SUM(b) AS sum$1])
: +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
-+- Exchange(distribution=[hash[a]], exchange_mode=[BATCH])
++- Exchange(distribution=[hash[a]])
+- Calc(select=[c, a, b], where=[<(b, 10)])
+- Reused(reference_id=[1])
]]>
@@ -983,7 +983,7 @@ LogicalProject(a=[$0], c=[$1], c0=[$3])
<Resource name="planAfter">
<![CDATA[
Calc(select=[a, c, c0])
-+- HashJoin(joinType=[InnerJoin], where=[=(a, a0)], select=[a, c, a0, c0], build=[left])
++- HashJoin(joinType=[InnerJoin], where=[=(a, a0)], select=[a, c, a0, c0], build=[right])
:- Exchange(distribution=[hash[a]], exchange_mode=[BATCH], reuse_id=[1])
: +- Union(all=[true], union=[a, c])
: :- Calc(select=[a, c], where=[>(b, 10)])
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/BroadcastHashSemiAntiJoinTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/BroadcastHashSemiAntiJoinTest.xml
new file mode 100644
index 0000000..9d9cd80
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/BroadcastHashSemiAntiJoinTest.xml
@@ -0,0 +1,1956 @@
+<?xml version="1.0" ?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements. See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to you under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+<Root>
+ <TestCase name="testExistsAndNotExists">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE NOT EXISTS (SELECT * FROM r) AND NOT EXISTS (SELECT * FROM t WHERE l.a = t.i AND t.j < 100)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[AND(NOT(EXISTS({
+LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+})), NOT(EXISTS({
+LogicalFilter(condition=[AND(=($cor0.a, $0), <($1, 100))])
+ LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+})))], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftAntiJoin], where=[=(a, i)], select=[a, b, c], isBroadcast=[true], build=[right])
+:- NestedLoopJoin(joinType=[LeftAntiJoin], where=[$f0], select=[a, b, c], build=[right], singleRowJoin=[true])
+: :- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+: +- Exchange(distribution=[broadcast])
+: +- Calc(select=[IS NOT NULL(m) AS $f0])
+: +- HashAggregate(isMerge=[true], select=[Final_MIN(min$0) AS m])
+: +- Exchange(distribution=[single])
+: +- LocalHashAggregate(select=[Partial_MIN(i) AS min$0])
+: +- Calc(select=[true AS i])
+: +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[i], where=[<(j, 100)])
+ +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testExistsWithCorrelated_AggInSubQuery">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE EXISTS (SELECT MAX(e) FROM r WHERE l.b = r.e AND d < 100 AND l.c = r.f GROUP BY d, true, f, 1)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[EXISTS({
+LogicalAggregate(group=[{0, 1, 2, 3}], EXPR$0=[MAX($4)])
+ LogicalProject(d=[$0], $f1=[true], f=[$2], $f3=[1], e=[$1])
+ LogicalFilter(condition=[AND(=($cor0.b, $1), <($0, 100), =($cor0.c, $2))])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+})], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[AND(=(b, e), =(c, f))], select=[a, b, c], isBroadcast=[true], build=[right])
+:- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[f, e])
+ +- HashAggregate(isMerge=[true], groupBy=[d, e, f], select=[d, e, f, Final_MAX(max$0) AS EXPR$0])
+ +- Exchange(distribution=[hash[d, e, f]])
+ +- LocalHashAggregate(groupBy=[d, e, f], select=[d, e, f, Partial_MAX(e) AS max$0])
+ +- Calc(select=[d, e, f], where=[<(d, 100)])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testExistsWithCorrelated_LateralTableInSubQuery">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE EXISTS (SELECT * FROM r, LATERAL TABLE(table_func(f)) AS T(f1) WHERE a = d)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[EXISTS({
+LogicalFilter(condition=[=($cor1.a, $0)])
+ LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{2}])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+ LogicalTableFunctionScan(invocation=[table_func($cor0.f)], rowType=[RecordType(VARCHAR(65536) f0)], elementType=[class [Ljava.lang.Object;])
+})], variablesSet=[[$cor1]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[=(a, d)], select=[a, b, c], isBroadcast=[true], build=[right])
+:- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[d])
+ +- Correlate(invocation=[table_func($cor0.f)], correlate=[table(table_func($cor0.f))], select=[d,e,f,f0], rowType=[RecordType(INTEGER d, BIGINT e, VARCHAR(65536) f, VARCHAR(65536) f0)], joinType=[INNER])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testExistsWithCorrelated_OverInSubQuery">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE EXISTS (SELECT MAX(r.e) OVER() FROM r WHERE l.c = r.f GROUP BY r.e)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[EXISTS({
+LogicalAggregate(group=[{0}])
+ LogicalProject(e=[$1])
+ LogicalFilter(condition=[=($cor0.c, $2)])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+})], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[=(c, f)], select=[a, b, c], isBroadcast=[true], build=[right])
+:- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[f])
+ +- HashAggregate(isMerge=[true], groupBy=[e, f], select=[e, f])
+ +- Exchange(distribution=[hash[e, f]])
+ +- LocalHashAggregate(groupBy=[e, f], select=[e, f])
+ +- Calc(select=[e, f])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testExistsWithCorrelated_SimpleCondition">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE EXISTS (SELECT * FROM r WHERE a = d)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[EXISTS({
+LogicalFilter(condition=[=($cor0.a, $0)])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+})], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[=(a, d)], select=[a, b, c], isBroadcast=[true], build=[right])
+:- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[d])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testExistsWithUncorrelated_ComplexCondition">
+ <Resource name="sql">
+ <![CDATA[SELECT a + 10, c FROM l WHERE b > 10 AND NOT (c like 'abc' OR NOT EXISTS (SELECT d FROM r))]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(EXPR$0=[+($0, 10)], c=[$2])
++- LogicalFilter(condition=[AND(>($1, 10), NOT(OR(LIKE($2, _UTF-16LE'abc'), NOT(EXISTS({
+LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+})))))])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[+(a, 10) AS EXPR$0, c])
++- NestedLoopJoin(joinType=[LeftSemiJoin], where=[$f0], select=[a, b, c], build=[right], singleRowJoin=[true])
+ :- Calc(select=[a, b, c], where=[AND(>(b, 10), NOT(LIKE(c, _UTF-16LE'abc')))])
+ : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+ +- Exchange(distribution=[broadcast])
+ +- Calc(select=[IS NOT NULL(m) AS $f0])
+ +- HashAggregate(isMerge=[true], select=[Final_MIN(min$0) AS m])
+ +- Exchange(distribution=[single])
+ +- LocalHashAggregate(select=[Partial_MIN(i) AS min$0])
+ +- Calc(select=[true AS i])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testExistsWithUncorrelated_LateralTableInSubQuery">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE EXISTS (SELECT * FROM r, LATERAL TABLE(table_func(f)) AS T(f1) WHERE EXISTS (SELECT * FROM t))]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[EXISTS({
+LogicalFilter(condition=[EXISTS({
+LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+})])
+ LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{2}])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+ LogicalTableFunctionScan(invocation=[table_func($cor0.f)], rowType=[RecordType(VARCHAR(65536) f0)], elementType=[class [Ljava.lang.Object;])
+})])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+NestedLoopJoin(joinType=[LeftSemiJoin], where=[$f0], select=[a, b, c], build=[right], singleRowJoin=[true])
+:- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[IS NOT NULL(m) AS $f0])
+ +- HashAggregate(isMerge=[true], select=[Final_MIN(min$0) AS m])
+ +- Exchange(distribution=[single])
+ +- LocalHashAggregate(select=[Partial_MIN(i) AS min$0])
+ +- Calc(select=[true AS i])
+ +- NestedLoopJoin(joinType=[LeftSemiJoin], where=[$f0], select=[d, e, f, f0], build=[right], singleRowJoin=[true])
+ :- Correlate(invocation=[table_func($cor0.f)], correlate=[table(table_func($cor0.f))], select=[d,e,f,f0], rowType=[RecordType(INTEGER d, BIGINT e, VARCHAR(65536) f, VARCHAR(65536) f0)], joinType=[INNER])
+ : +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+ +- Exchange(distribution=[broadcast])
+ +- Calc(select=[IS NOT NULL(m) AS $f0])
+ +- HashAggregate(isMerge=[true], select=[Final_MIN(min$0) AS m])
+ +- Exchange(distribution=[single])
+ +- LocalHashAggregate(select=[Partial_MIN(i) AS min$0])
+ +- Calc(select=[true AS i])
+ +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testExistsWithUncorrelated_SimpleCondition1">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE EXISTS (SELECT * FROM r)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[EXISTS({
+LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+})])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+NestedLoopJoin(joinType=[LeftSemiJoin], where=[$f0], select=[a, b, c], build=[right], singleRowJoin=[true])
+:- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[IS NOT NULL(m) AS $f0])
+ +- HashAggregate(isMerge=[true], select=[Final_MIN(min$0) AS m])
+ +- Exchange(distribution=[single])
+ +- LocalHashAggregate(select=[Partial_MIN(i) AS min$0])
+ +- Calc(select=[true AS i])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testExistsWithUncorrelated_SimpleCondition2">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE EXISTS (SELECT * FROM r) AND b > 10]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[AND(EXISTS({
+LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+}), >($1, 10))])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+NestedLoopJoin(joinType=[LeftSemiJoin], where=[$f0], select=[a, b, c], build=[right], singleRowJoin=[true])
+:- Calc(select=[a, b, c], where=[>(b, 10)])
+: +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[IS NOT NULL(m) AS $f0])
+ +- HashAggregate(isMerge=[true], select=[Final_MIN(min$0) AS m])
+ +- Exchange(distribution=[single])
+ +- LocalHashAggregate(select=[Partial_MIN(i) AS min$0])
+ +- Calc(select=[true AS i])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testExistsWithUncorrelated_UnionInSubQuery">
+ <Resource name="sql">
+ <![CDATA[SELECT a FROM l WHERE EXISTS (SELECT e, f FROM r WHERE d > 10 UNION SELECT j, k FROM t WHERE i < 100)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0])
++- LogicalFilter(condition=[EXISTS({
+LogicalUnion(all=[false])
+ LogicalProject(e=[$1], f=[$2])
+ LogicalFilter(condition=[>($0, 10)])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+ LogicalProject(j=[$1], k=[$2])
+ LogicalFilter(condition=[<($0, 100)])
+ LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+})])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[a])
++- NestedLoopJoin(joinType=[LeftSemiJoin], where=[$f0], select=[a, b, c], build=[right], singleRowJoin=[true])
+ :- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+ +- Exchange(distribution=[broadcast])
+ +- Calc(select=[IS NOT NULL(m) AS $f0])
+ +- HashAggregate(isMerge=[true], select=[Final_MIN(min$0) AS m])
+ +- Exchange(distribution=[single])
+ +- LocalHashAggregate(select=[Partial_MIN(i) AS min$0])
+ +- Calc(select=[true AS i])
+ +- HashAggregate(isMerge=[true], groupBy=[e, f], select=[e, f])
+ +- Exchange(distribution=[hash[e, f]])
+ +- LocalHashAggregate(groupBy=[e, f], select=[e, f])
+ +- Union(all=[true], union=[e, f])
+ :- Calc(select=[e, f], where=[>(d, 10)])
+ : +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+ +- Calc(select=[j, k], where=[<(i, 100)])
+ +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInExists1">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE EXISTS (SELECT * FROM r WHERE l.a = r.d) AND a IN (SELECT i FROM t WHERE l.b = t.j)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[AND(EXISTS({
+LogicalFilter(condition=[=($cor0.a, $0)])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+}), IN($0, {
+LogicalProject(i=[$0])
+ LogicalFilter(condition=[=($cor0.b, $1)])
+ LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+}))], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[AND(=(a, i), =(b, j))], select=[a, b, c], isBroadcast=[true], build=[right])
+:- HashJoin(joinType=[LeftSemiJoin], where=[=(a, d)], select=[a, b, c], isBroadcast=[true], build=[right])
+: :- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+: +- Exchange(distribution=[broadcast])
+: +- Calc(select=[d])
+: +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[i, j])
+ +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInExists2">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE b IN (SELECT j FROM t) AND EXISTS (SELECT * FROM r WHERE l.a = r.d)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[AND(IN($1, {
+LogicalProject(j=[$1])
+ LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+}), EXISTS({
+LogicalFilter(condition=[=($cor0.a, $0)])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+}))], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[=(a, d)], select=[a, b, c], isBroadcast=[true], build=[right])
+:- HashJoin(joinType=[LeftSemiJoin], where=[=(b, j)], select=[a, b, c], isBroadcast=[true], build=[right])
+: :- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+: +- Exchange(distribution=[broadcast])
+: +- Calc(select=[j])
+: +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[d])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithCorrelated_AggInSubQuery1">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE b IN (SELECT MAX(r.e) FROM r WHERE l.c = r.f AND r.d < 3 GROUP BY r.f)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[IN($1, {
+LogicalProject(EXPR$0=[$1])
+ LogicalAggregate(group=[{0}], EXPR$0=[MAX($1)])
+ LogicalProject(f=[$2], e=[$1])
+ LogicalFilter(condition=[AND(=($cor0.c, $2), <($0, 3))])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+})], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[AND(=(b, EXPR$0), =(c, f))], select=[a, b, c], isBroadcast=[true], build=[right])
+:- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[EXPR$0, f])
+ +- HashAggregate(isMerge=[true], groupBy=[f], select=[f, Final_MAX(max$0) AS EXPR$0])
+ +- Exchange(distribution=[hash[f]])
+ +- LocalHashAggregate(groupBy=[f], select=[f, Partial_MAX(e) AS max$0])
+ +- Calc(select=[f, e], where=[<(d, 3)])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithCorrelated_AggInSubQuery2">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE (b, a) IN (SELECT COUNT(*), d FROM r WHERE l.c = r.f GROUP BY d, true, e, 1)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[IN($1, $0, {
+LogicalProject(EXPR$0=[$4], d=[$0])
+ LogicalAggregate(group=[{0, 1, 2, 3}], EXPR$0=[COUNT()])
+ LogicalProject(d=[$0], $f1=[true], e=[$1], $f3=[1])
+ LogicalFilter(condition=[=($cor0.c, $2)])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+})], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[AND(=(b, EXPR$0), =(a, d), =(c, f))], select=[a, b, c], isBroadcast=[true], build=[right])
+:- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[EXPR$0, d, f])
+ +- HashAggregate(isMerge=[true], groupBy=[d, e, f], select=[d, e, f, Final_COUNT(count1$0) AS EXPR$0])
+ +- Exchange(distribution=[hash[d, e, f]])
+ +- LocalHashAggregate(groupBy=[d, e, f], select=[d, e, f, Partial_COUNT(*) AS count1$0])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithCorrelated_ComplexCondition1">
+ <Resource name="sql">
+ <![CDATA[SELECT a + 10, c FROM l WHERE NOT(NOT(substring(c, 1, 5) IN (SELECT substring(f, 1, 5) FROM r WHERE l.b + 1 = r.e)))]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(EXPR$0=[+($0, 10)], c=[$2])
++- LogicalFilter(condition=[IN(SUBSTRING($2, 1, 5), {
+LogicalProject(EXPR$0=[SUBSTRING($2, 1, 5)])
+ LogicalFilter(condition=[=(+($cor0.b, 1), $1)])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+})], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[+(a, 10) AS EXPR$0, c])
++- HashJoin(joinType=[LeftSemiJoin], where=[AND(=($f3, EXPR$0), =($f4, e))], select=[a, b, c, $f3, $f4], isBroadcast=[true], build=[right])
+ :- Calc(select=[a, b, c, SUBSTRING(c, 1, 5) AS $f3, +(b, 1) AS $f4])
+ : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+ +- Exchange(distribution=[broadcast])
+ +- Calc(select=[SUBSTRING(f, 1, 5) AS EXPR$0, e])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithCorrelated_ComplexCondition2">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE b > 10 AND NOT (c like 'abc' OR a NOT IN (SELECT d FROM r WHERE l.b = r.e))]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[AND(>($1, 10), NOT(LIKE($2, _UTF-16LE'abc')), IN($0, {
+LogicalProject(d=[$0])
+ LogicalFilter(condition=[=($cor0.b, $1)])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+}))], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[AND(=(a, d), =(b, e))], select=[a, b, c], isBroadcast=[true], build=[right])
+:- Calc(select=[a, b, c], where=[AND(>(b, 10), NOT(LIKE(c, _UTF-16LE'abc')))])
+: +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[d, e])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithCorrelated_JoinInSubQuery1">
+ <Resource name="sql">
+ <![CDATA[SELECT c FROM l WHERE b IN (WITH r1 AS (SELECT e FROM r WHERE l.a = r.d AND r.e < 50) SELECT e FROM r1 INNER JOIN (SELECT j FROM t WHERE l.c = t.k AND i > 10) t2 ON r1.e = t2.j)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(c=[$2])
++- LogicalFilter(condition=[IN($1, {
+LogicalProject(e=[$0])
+ LogicalJoin(condition=[=($0, $1)], joinType=[inner])
+ LogicalProject(e=[$1])
+ LogicalFilter(condition=[AND(=($cor0.a, $0), <($1, 50))])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+ LogicalProject(j=[$1])
+ LogicalFilter(condition=[AND(=($cor0.c, $2), >($0, 10))])
+ LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+})], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[c])
++- HashJoin(joinType=[LeftSemiJoin], where=[AND(=(b, e), =(a, d), =(c, k))], select=[a, b, c], isBroadcast=[true], build=[right])
+ :- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+ +- Exchange(distribution=[broadcast])
+ +- Calc(select=[e, d, k])
+ +- HashJoin(joinType=[InnerJoin], where=[=(e, j)], select=[e, d, j, k], isBroadcast=[true], build=[left])
+ :- Exchange(distribution=[broadcast])
+ : +- Calc(select=[e, d], where=[<(e, 50)])
+ : +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+ +- Calc(select=[j, k], where=[>(i, 10)])
+ +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithCorrelated_JoinInSubQuery2">
+ <Resource name="sql">
+ <![CDATA[SELECT c FROM l WHERE b IN (WITH r1 AS (SELECT e, f FROM r WHERE l.a = r.d AND r.e < 50) SELECT t.j FROM r1 LEFT JOIN t ON r1.f = t.k)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(c=[$2])
++- LogicalFilter(condition=[IN($1, {
+LogicalProject(j=[$3])
+ LogicalJoin(condition=[=($1, $4)], joinType=[left])
+ LogicalProject(e=[$1], f=[$2])
+ LogicalFilter(condition=[AND(=($cor0.a, $0), <($1, 50))])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+ LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+})], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[c])
++- HashJoin(joinType=[LeftSemiJoin], where=[AND(=(b, j), =(a, d))], select=[a, b, c], isBroadcast=[true], build=[right])
+ :- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+ +- Exchange(distribution=[broadcast])
+ +- Calc(select=[j, d])
+ +- HashJoin(joinType=[LeftOuterJoin], where=[=(f, k)], select=[f, d, j, k], isBroadcast=[true], build=[right])
+ :- Calc(select=[f, d], where=[<(e, 50)])
+ : +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+ +- Exchange(distribution=[broadcast])
+ +- Calc(select=[j, k])
+ +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithCorrelated_JoinInSubQuery3">
+ <Resource name="sql">
+ <![CDATA[SELECT c FROM l WHERE a IN (SELECT d FROM r RIGHT JOIN (SELECT i FROM t WHERE l.c = t.k AND i > 10) t2 ON r.d = t2.i)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(c=[$2])
++- LogicalFilter(condition=[IN($0, {
+LogicalProject(d=[$0])
+ LogicalJoin(condition=[=($0, $3)], joinType=[right])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+ LogicalProject(i=[$0])
+ LogicalFilter(condition=[AND(=($cor0.c, $2), >($0, 10))])
+ LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+})], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[c])
++- HashJoin(joinType=[LeftSemiJoin], where=[AND(=(a, d), =(c, k))], select=[a, b, c], isBroadcast=[true], build=[right])
+ :- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+ +- Exchange(distribution=[broadcast])
+ +- Calc(select=[d, k])
+ +- HashJoin(joinType=[RightOuterJoin], where=[=(d, i)], select=[d, i, k], isBroadcast=[true], build=[left])
+ :- Exchange(distribution=[broadcast])
+ : +- Calc(select=[d])
+ : +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+ +- Calc(select=[i, k], where=[>(i, 10)])
+ +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithCorrelated_LateralTableInSubQuery">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE c IN (SELECT f1 FROM r, LATERAL TABLE(table_func(f)) AS T(f1) WHERE a = d)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[IN($2, {
+LogicalProject(f1=[$3])
+ LogicalFilter(condition=[=($cor1.a, $0)])
+ LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{2}])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+ LogicalTableFunctionScan(invocation=[table_func($cor0.f)], rowType=[RecordType(VARCHAR(65536) f0)], elementType=[class [Ljava.lang.Object;])
+})], variablesSet=[[$cor1]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[AND(=(c, f1), =(a, d))], select=[a, b, c], isBroadcast=[true], build=[right])
+:- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[f0 AS f1, d])
+ +- Correlate(invocation=[table_func($cor0.f)], correlate=[table(table_func($cor0.f))], select=[d,e,f,f0], rowType=[RecordType(INTEGER d, BIGINT e, VARCHAR(65536) f, VARCHAR(65536) f0)], joinType=[INNER])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithCorrelated_MultiFields">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE (a, SUBSTRING(c, 1, 5)) IN (SELECT d, SUBSTRING(f, 1, 5) FROM r WHERE l.b = r.e)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[IN($0, SUBSTRING($2, 1, 5), {
+LogicalProject(d=[$0], EXPR$1=[SUBSTRING($2, 1, 5)])
+ LogicalFilter(condition=[=($cor0.b, $1)])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+})], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[a, b, c])
++- HashJoin(joinType=[LeftSemiJoin], where=[AND(=(a, d), =($f3, EXPR$1), =(b, e))], select=[a, b, c, $f3], isBroadcast=[true], build=[right])
+ :- Calc(select=[a, b, c, SUBSTRING(c, 1, 5) AS $f3])
+ : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+ +- Exchange(distribution=[broadcast])
+ +- Calc(select=[d, SUBSTRING(f, 1, 5) AS EXPR$1, e])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithCorrelated_NonEquiCondition1">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE a IN (SELECT d FROM r WHERE l.b > r.e)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[IN($0, {
+LogicalProject(d=[$0])
+ LogicalFilter(condition=[>($cor0.b, $1)])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+})], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[AND(=(a, d), >(b, e))], select=[a, b, c], isBroadcast=[true], build=[right])
+:- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[d, e])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithCorrelated_NonEquiCondition2">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE a IN (SELECT d FROM r WHERE l.b > 10)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[IN($0, {
+LogicalProject(d=[$0])
+ LogicalFilter(condition=[>($cor0.b, 10)])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+})], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[a, b, c])
++- HashJoin(joinType=[LeftSemiJoin], where=[=(a, d)], select=[a, b, c, $f3], isBroadcast=[true], build=[right])
+ :- Calc(select=[a, b, c, CAST(IS NOT NULL(b)) AS $f3], where=[>(b, 10)])
+ : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+ +- Exchange(distribution=[broadcast])
+ +- Calc(select=[d])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithCorrelated_OverInSubQuery1">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE b IN (SELECT rk FROM (SELECT d, e, RANK() OVER(PARTITION BY d ORDER BY e) AS rk FROM r) t WHERE l.a <> t.d)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[IN($1, {
+LogicalProject(rk=[$2])
+ LogicalFilter(condition=[<>($cor0.a, $0)])
+ LogicalProject(d=[$0], e=[$1], rk=[RANK() OVER (PARTITION BY $0 ORDER BY $1 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+})], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[AND(=(b, rk), <>(a, d))], select=[a, b, c], isBroadcast=[true], build=[right])
+:- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[w0$o0 AS rk, d])
+ +- OverAggregate(partitionBy=[d], orderBy=[e ASC], window#0=[RANK(*) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[d, e, f, w0$o0])
+ +- Sort(orderBy=[d ASC, e ASC])
+ +- Exchange(distribution=[hash[d]])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithCorrelated_ScalarQuery">
+ <Resource name="sql">
+ <![CDATA[SELECT a FROM l WHERE (SELECT MAX(d) FROM r WHERE e IN (SELECT j FROM t)) IN (SELECT i FROM t WHERE t.k = l.c)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0])
++- LogicalFilter(condition=[IN($SCALAR_QUERY({
+LogicalAggregate(group=[{}], EXPR$0=[MAX($0)])
+ LogicalProject(d=[$0])
+ LogicalFilter(condition=[IN($1, {
+LogicalProject(j=[$1])
+ LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+})])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+}), {
+LogicalProject(i=[$0])
+ LogicalFilter(condition=[=($2, $cor0.c)])
+ LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+})], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[a])
++- HashJoin(joinType=[LeftSemiJoin], where=[AND(=(EXPR$0, i), =(k, c))], select=[a, b, c, EXPR$0], isBroadcast=[true], build=[right])
+ :- NestedLoopJoin(joinType=[LeftOuterJoin], where=[true], select=[a, b, c, EXPR$0], build=[right], singleRowJoin=[true])
+ : :- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+ : +- Exchange(distribution=[broadcast])
+ : +- SortAggregate(isMerge=[true], select=[Final_MAX(max$0) AS EXPR$0])
+ : +- Exchange(distribution=[single])
+ : +- LocalSortAggregate(select=[Partial_MAX(d) AS max$0])
+ : +- Calc(select=[d])
+ : +- HashJoin(joinType=[InnerJoin], where=[=(e, j)], select=[d, e, j], isBroadcast=[true], build=[right])
+ : :- Calc(select=[d, e])
+ : : +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+ : +- Exchange(distribution=[broadcast])
+ : +- HashAggregate(isMerge=[true], groupBy=[j], select=[j])
+ : +- Exchange(distribution=[hash[j]])
+ : +- LocalHashAggregate(groupBy=[j], select=[j])
+ : +- Calc(select=[j])
+ : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1])
+ +- Exchange(distribution=[broadcast])
+ +- Calc(select=[i, k])
+ +- Reused(reference_id=[1])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithCorrelated_SimpleCondition1">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE a IN (SELECT d FROM r WHERE l.b = r.e)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[IN($0, {
+LogicalProject(d=[$0])
+ LogicalFilter(condition=[=($cor0.b, $1)])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+})], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[AND(=(a, d), =(b, e))], select=[a, b, c], isBroadcast=[true], build=[right])
+:- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[d, e])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithCorrelated_SimpleCondition2">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE b > 1 AND a IN (SELECT d FROM r WHERE l.b = r.e AND r.d > 10)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[AND(>($1, 1), IN($0, {
+LogicalProject(d=[$0])
+ LogicalFilter(condition=[AND(=($cor0.b, $1), >($0, 10))])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+}))], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[AND(=(a, d), =(b, e))], select=[a, b, c], isBroadcast=[true], build=[right])
+:- Calc(select=[a, b, c], where=[>(b, 1)])
+: +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[d, e], where=[>(d, 10)])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithCorrelated_SimpleCondition3">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE a IN (SELECT d FROM r where CAST(l.b AS INTEGER) = r.d)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[IN($0, {
+LogicalProject(d=[$0])
+ LogicalFilter(condition=[=(CAST($cor0.b):INTEGER, $0)])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+})], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[a, b, c])
++- HashJoin(joinType=[LeftSemiJoin], where=[AND(=(a, d), =(b0, d))], select=[a, b, c, b0], isBroadcast=[true], build=[right])
+ :- Calc(select=[a, b, c, CAST(b) AS b0])
+ : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+ +- Exchange(distribution=[broadcast])
+ +- Calc(select=[d])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithUncorrelated_AggInSubQuery1">
+ <Resource name="sql">
+ <![CDATA[SELECT a FROM l WHERE b IN (SELECT MAX(e) FROM r WHERE d < 3 GROUP BY f)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0])
++- LogicalFilter(condition=[IN($1, {
+LogicalProject(EXPR$0=[$1])
+ LogicalAggregate(group=[{0}], EXPR$0=[MAX($1)])
+ LogicalProject(f=[$2], e=[$1])
+ LogicalFilter(condition=[<($0, 3)])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+})])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[a])
++- HashJoin(joinType=[LeftSemiJoin], where=[=(b, EXPR$0)], select=[a, b, c], isBroadcast=[true], build=[right])
+ :- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+ +- Exchange(distribution=[broadcast])
+ +- Calc(select=[EXPR$0])
+ +- HashAggregate(isMerge=[true], groupBy=[f], select=[f, Final_MAX(max$0) AS EXPR$0])
+ +- Exchange(distribution=[hash[f]])
+ +- LocalHashAggregate(groupBy=[f], select=[f, Partial_MAX(e) AS max$0])
+ +- Calc(select=[f, e], where=[<(d, 3)])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithUncorrelated_AggInSubQuery2">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE (a, b) IN(SELECT d, COUNT(*) FROM r GROUP BY d)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[IN($0, $1, {
+LogicalAggregate(group=[{0}], EXPR$1=[COUNT()])
+ LogicalProject(d=[$0])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+})])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[AND(=(a, d), =(b, EXPR$1))], select=[a, b, c], isBroadcast=[true], build=[right])
+:- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- HashAggregate(isMerge=[true], groupBy=[d], select=[d, Final_COUNT(count1$0) AS EXPR$1])
+ +- Exchange(distribution=[hash[d]])
+ +- LocalHashAggregate(groupBy=[d], select=[d, Partial_COUNT(*) AS count1$0])
+ +- Calc(select=[d])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithUncorrelated_ComplexCondition1">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE b > 10 AND NOT (c like 'abc' OR a NOT IN (SELECT d FROM r))]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[AND(>($1, 10), NOT(LIKE($2, _UTF-16LE'abc')), IN($0, {
+LogicalProject(d=[$0])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+}))])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[=(a, d)], select=[a, b, c], isBroadcast=[true], build=[right])
+:- Calc(select=[a, b, c], where=[AND(>(b, 10), NOT(LIKE(c, _UTF-16LE'abc')))])
+: +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[d])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithUncorrelated_ComplexCondition2">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE (a IN (SELECT d FROM r)) IS TRUE]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[IS TRUE(IN($0, {
+LogicalProject(d=[$0])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+}))])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[=(a, d)], select=[a, b, c], isBroadcast=[true], build=[right])
+:- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[d])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithUncorrelated_ComplexCondition4">
+ <Resource name="sql">
+ <![CDATA[SELECT a FROM l WHERE (SELECT MAX(e) FROM r WHERE d > 0) IN (SELECT j FROM t)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0])
++- LogicalFilter(condition=[IN($SCALAR_QUERY({
+LogicalAggregate(group=[{}], EXPR$0=[MAX($0)])
+ LogicalProject(e=[$1])
+ LogicalFilter(condition=[>($0, 0)])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+}), {
+LogicalProject(j=[$1])
+ LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+})])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[a])
++- HashJoin(joinType=[LeftSemiJoin], where=[=(EXPR$0, j)], select=[a, b, c, EXPR$0], isBroadcast=[true], build=[right])
+ :- NestedLoopJoin(joinType=[LeftOuterJoin], where=[true], select=[a, b, c, EXPR$0], build=[right], singleRowJoin=[true])
+ : :- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+ : +- Exchange(distribution=[broadcast])
+ : +- HashAggregate(isMerge=[true], select=[Final_MAX(max$0) AS EXPR$0])
+ : +- Exchange(distribution=[single])
+ : +- LocalHashAggregate(select=[Partial_MAX(e) AS max$0])
+ : +- Calc(select=[e], where=[>(d, 0)])
+ : +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+ +- Exchange(distribution=[broadcast])
+ +- Calc(select=[j])
+ +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithUncorrelated_ComplexCondition5">
+ <Resource name="sql">
+ <![CDATA[SELECT b FROM l WHERE a IN (SELECT d FROM r WHERE e > 10) AND b > (SELECT 0.5 * SUM(j) FROM t WHERE t.i < 100)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(b=[$1])
++- LogicalFilter(condition=[AND(IN($0, {
+LogicalProject(d=[$0])
+ LogicalFilter(condition=[>($1, 10)])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+}), >($1, $SCALAR_QUERY({
+LogicalProject(EXPR$0=[*(0.5:DECIMAL(2, 1), $0)])
+ LogicalAggregate(group=[{}], agg#0=[SUM($0)])
+ LogicalProject(j=[$1])
+ LogicalFilter(condition=[<($0, 100)])
+ LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+})))])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[b])
++- NestedLoopJoin(joinType=[InnerJoin], where=[>(b, $f0)], select=[b, $f0], build=[right], singleRowJoin=[true])
+ :- Calc(select=[b])
+ : +- HashJoin(joinType=[LeftSemiJoin], where=[=(a, d)], select=[a, b, c], isBroadcast=[true], build=[right])
+ : :- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+ : +- Exchange(distribution=[broadcast])
+ : +- Calc(select=[d], where=[>(e, 10)])
+ : +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+ +- Exchange(distribution=[broadcast])
+ +- SortAggregate(isMerge=[false], select=[SINGLE_VALUE(EXPR$0) AS $f0])
+ +- Exchange(distribution=[single])
+ +- Calc(select=[*(0.5:DECIMAL(2, 1), $f0) AS EXPR$0])
+ +- HashAggregate(isMerge=[true], select=[Final_SUM(sum$0) AS $f0])
+ +- Exchange(distribution=[single])
+ +- LocalHashAggregate(select=[Partial_SUM(j) AS sum$0])
+ +- Calc(select=[j], where=[<(i, 100)])
+ +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithUncorrelated_Having">
+ <Resource name="sql">
+ <![CDATA[SELECT SUM(a) AS s FROM l GROUP BY b HAVING COUNT(*) > 2 AND MAX(b) IN (SELECT e FROM r)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(s=[$1])
++- LogicalFilter(condition=[AND(>($2, 2), IN($3, {
+LogicalProject(e=[$1])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+}))])
+ +- LogicalAggregate(group=[{0}], s=[SUM($1)], agg#1=[COUNT()], agg#2=[MAX($0)])
+ +- LogicalProject(b=[$1], a=[$0])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[s])
++- HashJoin(joinType=[LeftSemiJoin], where=[=($f3, e)], select=[b, s, $f2, $f3], isBroadcast=[true], build=[right])
+ :- Calc(select=[b, s, $f2, $f3], where=[>($f2, 2)])
+ : +- HashAggregate(isMerge=[true], groupBy=[b], select=[b, Final_SUM(sum$0) AS s, Final_COUNT(count1$1) AS $f2, Final_MAX(max$2) AS $f3])
+ : +- Exchange(distribution=[hash[b]])
+ : +- LocalHashAggregate(groupBy=[b], select=[b, Partial_SUM(a) AS sum$0, Partial_COUNT(*) AS count1$1, Partial_MAX(b) AS max$2])
+ : +- Calc(select=[b, a])
+ : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+ +- Exchange(distribution=[broadcast])
+ +- Calc(select=[e])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithUncorrelated_LateralTableInSubQuery">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE c IN (SELECT f1 FROM r, LATERAL TABLE(table_func(f)) AS T(f1))]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[IN($2, {
+LogicalProject(f1=[$3])
+ LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{2}])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+ LogicalTableFunctionScan(invocation=[table_func($cor0.f)], rowType=[RecordType(VARCHAR(65536) f0)], elementType=[class [Ljava.lang.Object;])
+})])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[=(c, f1)], select=[a, b, c], isBroadcast=[true], build=[right])
+:- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[f0 AS f1])
+ +- Correlate(invocation=[table_func($cor0.f)], correlate=[table(table_func($cor0.f))], select=[d,e,f,f0], rowType=[RecordType(INTEGER d, BIGINT e, VARCHAR(65536) f, VARCHAR(65536) f0)], joinType=[INNER])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithUncorrelated_MultiFields">
+ <Resource name="sql">
+ <![CDATA[
+SELECT * FROM l WHERE
+ (a + 10, SUBSTRING(c, 1, 5)) IN (SELECT d + 100, SUBSTRING(f, 1, 5) FROM r)
+ ]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[IN(+($0, 10), SUBSTRING($2, 1, 5), {
+LogicalProject(EXPR$0=[+($0, 100)], EXPR$1=[SUBSTRING($2, 1, 5)])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+})])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[a, b, c])
++- HashJoin(joinType=[LeftSemiJoin], where=[AND(=($f3, EXPR$0), =($f4, EXPR$1))], select=[a, b, c, $f3, $f4], isBroadcast=[true], build=[right])
+ :- Calc(select=[a, b, c, +(a, 10) AS $f3, SUBSTRING(c, 1, 5) AS $f4])
+ : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+ +- Exchange(distribution=[broadcast])
+ +- Calc(select=[+(d, 100) AS EXPR$0, SUBSTRING(f, 1, 5) AS EXPR$1])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithUncorrelated_OverInSubQuery">
+ <Resource name="sql">
+ <![CDATA[
+SELECT * FROM l WHERE (a, b) IN
+ (SELECT MAX(r.d) OVER(), MIN(r.e) OVER(PARTITION BY f ORDER BY d) FROM r)
+ ]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[IN($0, $1, {
+LogicalProject(EXPR$0=[MAX($0) OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)], EXPR$1=[MIN($1) OVER (PARTITION BY $2 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+})])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[AND(=(a, $0), =(b, $1))], select=[a, b, c], isBroadcast=[true], build=[right])
+:- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[w0$o0 AS $0, w1$o0 AS $1])
+ +- OverAggregate(partitionBy=[f], orderBy=[d ASC], window#0=[MIN(e) AS w1$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[d, e, f, w0$o0, w1$o0])
+ +- Sort(orderBy=[f ASC, d ASC])
+ +- Exchange(distribution=[hash[f]])
+ +- OverAggregate(window#0=[MAX(d) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], select=[d, e, f, w0$o0])
+ +- Exchange(distribution=[single])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithUncorrelated_SimpleCondition1">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE a IN (SELECT d FROM r)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[IN($0, {
+LogicalProject(d=[$0])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+})])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[=(a, d)], select=[a, b, c], isBroadcast=[true], build=[right])
+:- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[d])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithUncorrelated_SimpleCondition2">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE a IN (SELECT d FROM r WHERE e < 100) AND b > 10]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[AND(IN($0, {
+LogicalProject(d=[$0])
+ LogicalFilter(condition=[<($1, 100)])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+}), >($1, 10))])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[=(a, d)], select=[a, b, c], isBroadcast=[true], build=[right])
+:- Calc(select=[a, b, c], where=[>(b, 10)])
+: +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[d], where=[<(e, 100)])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithUncorrelated_SimpleCondition3">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE a + 1 IN (SELECT d FROM r)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[IN(+($0, 1), {
+LogicalProject(d=[$0])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+})])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[a, b, c])
++- HashJoin(joinType=[LeftSemiJoin], where=[=($f3, d)], select=[a, b, c, $f3], isBroadcast=[true], build=[right])
+ :- Calc(select=[a, b, c, +(a, 1) AS $f3])
+ : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+ +- Exchange(distribution=[broadcast])
+ +- Calc(select=[d])
+ +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInWithUncorrelated_UnionInSubQuery">
+ <Resource name="sql">
+ <![CDATA[
+SELECT a FROM l WHERE b IN
+ (SELECT e FROM r WHERE d > 10 UNION SELECT i FROM t WHERE i < 100)
+ ]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0])
++- LogicalFilter(condition=[IN($1, {
+LogicalUnion(all=[false])
+ LogicalProject(e=[$1])
+ LogicalFilter(condition=[>($0, 10)])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+ LogicalProject(i=[$0])
+ LogicalFilter(condition=[<($0, 100)])
+ LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+})])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[a])
++- HashJoin(joinType=[LeftSemiJoin], where=[=(b, e)], select=[a, b, c], isBroadcast=[true], build=[right])
+ :- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+ +- Exchange(distribution=[broadcast])
+ +- HashAggregate(isMerge=[true], groupBy=[e], select=[e])
+ +- Exchange(distribution=[hash[e]])
+ +- LocalHashAggregate(groupBy=[e], select=[e])
+ +- Union(all=[true], union=[e])
+ :- Calc(select=[e], where=[>(d, 10)])
+ : +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+ +- Calc(select=[CAST(i) AS i], where=[<(i, 100)])
+ +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testMultiExistsWithCorrelate2">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE EXISTS (SELECT * FROM r WHERE EXISTS (SELECT * FROM t) AND l.a = r.d)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[EXISTS({
+LogicalFilter(condition=[AND(EXISTS({
+LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+}), =($cor0.a, $0))])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+})], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[=(a, d)], select=[a, b, c], isBroadcast=[true], build=[right])
+:- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[d])
+ +- NestedLoopJoin(joinType=[LeftSemiJoin], where=[$f0], select=[d, e, f], build=[right], singleRowJoin=[true])
+ :- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+ +- Exchange(distribution=[broadcast])
+ +- Calc(select=[IS NOT NULL(m) AS $f0])
+ +- HashAggregate(isMerge=[true], select=[Final_MIN(min$0) AS m])
+ +- Exchange(distribution=[single])
+ +- LocalHashAggregate(select=[Partial_MIN(i) AS min$0])
+ +- Calc(select=[true AS i])
+ +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testMultiExistsWithUncorrelated">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l, r WHERE l.a = r.d AND EXISTS (SELECT * FROM t t1 WHERE t1.i > 50) AND b >= 1 AND EXISTS (SELECT * FROM t t2 WHERE t2.j < 100)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], e=[$4], f=[$5])
++- LogicalFilter(condition=[AND(=($0, $3), EXISTS({
+LogicalFilter(condition=[>($0, 50)])
+ LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+}), >=($1, 1), EXISTS({
+LogicalFilter(condition=[<($1, 100)])
+ LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+}))])
+ +- LogicalJoin(condition=[true], joinType=[inner])
+ :- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+ +- LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+NestedLoopJoin(joinType=[LeftSemiJoin], where=[$f0], select=[a, b, c, d, e, f], build=[right], singleRowJoin=[true])
+:- NestedLoopJoin(joinType=[LeftSemiJoin], where=[$f0], select=[a, b, c, d, e, f], build=[right], singleRowJoin=[true])
+: :- HashJoin(joinType=[InnerJoin], where=[=(a, d)], select=[a, b, c, d, e, f], isBroadcast=[true], build=[left])
+: : :- Exchange(distribution=[broadcast])
+: : : +- Calc(select=[a, b, c], where=[>=(b, 1)])
+: : : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+: : +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+: +- Exchange(distribution=[broadcast])
+: +- Calc(select=[IS NOT NULL(m) AS $f0])
+: +- HashAggregate(isMerge=[true], select=[Final_MIN(min$0) AS m])
+: +- Exchange(distribution=[single])
+: +- LocalHashAggregate(select=[Partial_MIN(i) AS min$0])
+: +- Calc(select=[true AS i], where=[>(i, 50)])
+: +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[IS NOT NULL(m) AS $f0])
+ +- HashAggregate(isMerge=[true], select=[Final_MIN(min$0) AS m])
+ +- Exchange(distribution=[single])
+ +- LocalHashAggregate(select=[Partial_MIN(i) AS min$0])
+ +- Calc(select=[true AS i], where=[<(j, 100)])
+ +- Reused(reference_id=[1])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testMultiNotExistsWithCorrelate">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l, r WHERE l.a = r.d AND NOT EXISTS (SELECT * FROM t t1 WHERE l.b = t1.j AND t1.k > 50) AND c >= 1 AND NOT EXISTS (SELECT * FROM t t2 WHERE l.a = t2.i AND t2.j < 100)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], e=[$4], f=[$5])
++- LogicalFilter(condition=[AND(=($0, $3), NOT(EXISTS({
+LogicalFilter(condition=[AND(=($cor0.b, $1), >(CAST($2):BIGINT, 50))])
+ LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+})), >=(CAST($2):BIGINT, 1), NOT(EXISTS({
+LogicalFilter(condition=[AND(=($cor0.a, $0), <($1, 100))])
+ LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+})))], variablesSet=[[$cor0]])
+ +- LogicalJoin(condition=[true], joinType=[inner])
+ :- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+ +- LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftAntiJoin], where=[=(a, i)], select=[a, b, c, d, e, f], isBroadcast=[true], build=[right])
+:- HashJoin(joinType=[LeftAntiJoin], where=[=(b, j)], select=[a, b, c, d, e, f], isBroadcast=[true], build=[right])
+: :- HashJoin(joinType=[InnerJoin], where=[=(a, d)], select=[a, b, c, d, e, f], isBroadcast=[true], build=[left])
+: : :- Exchange(distribution=[broadcast])
+: : : +- Calc(select=[a, b, c], where=[>=(CAST(c), 1:BIGINT)])
+: : : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+: : +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+: +- Exchange(distribution=[broadcast])
+: +- Calc(select=[j], where=[>(CAST(k), 50:BIGINT)])
+: +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[i], where=[<(j, 100)])
+ +- Reused(reference_id=[1])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testMultiInWithCorrelated1">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE a IN (SELECT d FROM r WHERE c = f) AND b IN (SELECT j FROM t WHERE a = i AND k <> 'test')]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[AND(IN($0, {
+LogicalProject(d=[$0])
+ LogicalFilter(condition=[=($cor0.c, $2)])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+}), IN($1, {
+LogicalProject(j=[$1])
+ LogicalFilter(condition=[AND(=($cor0.a, $0), <>($2, _UTF-16LE'test'))])
+ LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+}))], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[AND(=(b, j), =(a, i))], select=[a, b, c], isBroadcast=[true], build=[right])
+:- HashJoin(joinType=[LeftSemiJoin], where=[AND(=(a, d), =(c, f))], select=[a, b, c], isBroadcast=[true], build=[right])
+: :- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+: +- Exchange(distribution=[broadcast])
+: +- Calc(select=[d, f])
+: +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[j, i], where=[<>(k, _UTF-16LE'test':VARCHAR(65536) CHARACTER SET "UTF-16LE")])
+ +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testMultiInWithCorrelated2">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE a IN (SELECT r.d FROM r WHERE l.b = r.e AND r.f IN (SELECT t.k FROM t WHERE r.e = t.j))]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[IN($0, {
+LogicalProject(d=[$0])
+ LogicalFilter(condition=[AND(=($cor0.b, $1), IN($2, {
+LogicalProject(k=[$2])
+ LogicalFilter(condition=[=($cor1.e, $1)])
+ LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+}))], variablesSet=[[$cor1]])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+})], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[AND(=(a, d), =(b, e))], select=[a, b, c], isBroadcast=[true], build=[right])
+:- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[d, e])
+ +- HashJoin(joinType=[LeftSemiJoin], where=[AND(=(f, k), =(e, j))], select=[d, e, f], isBroadcast=[true], build=[right])
+ :- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+ +- Exchange(distribution=[broadcast])
+ +- Calc(select=[k, j])
+ +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testMultiInWithCorrelated3">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE a IN (SELECT d FROM r) AND b IN (SELECT j FROM t WHERE t.k = l.c)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[AND(IN($0, {
+LogicalProject(d=[$0])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+}), IN($1, {
+LogicalProject(j=[$1])
+ LogicalFilter(condition=[=($2, $cor0.c)])
+ LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+}))], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[AND(=(b, j), =(k, c))], select=[a, b, c], isBroadcast=[true], build=[right])
+:- HashJoin(joinType=[LeftSemiJoin], where=[=(a, d)], select=[a, b, c], isBroadcast=[true], build=[right])
+: :- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+: +- Exchange(distribution=[broadcast])
+: +- Calc(select=[d])
+: +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[j, k])
+ +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testMultiInWithUncorrelated1">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE a IN (SELECT d FROM r) AND b IN (SELECT j FROM t)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[AND(IN($0, {
+LogicalProject(d=[$0])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+}), IN($1, {
+LogicalProject(j=[$1])
+ LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+}))])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[=(b, j)], select=[a, b, c], isBroadcast=[true], build=[right])
+:- HashJoin(joinType=[LeftSemiJoin], where=[=(a, d)], select=[a, b, c], isBroadcast=[true], build=[right])
+: :- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+: +- Exchange(distribution=[broadcast])
+: +- Calc(select=[d])
+: +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[j])
+ +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testMultiInWithUncorrelated2">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE a IN (SELECT d FROM r WHERE e IN (SELECT j FROM t))]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[IN($0, {
+LogicalProject(d=[$0])
+ LogicalFilter(condition=[IN($1, {
+LogicalProject(j=[$1])
+ LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+})])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+})])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftSemiJoin], where=[=(a, d)], select=[a, b, c], isBroadcast=[true], build=[right])
+:- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[d])
+ +- HashJoin(joinType=[LeftSemiJoin], where=[=(e, j)], select=[d, e, f], isBroadcast=[true], build=[right])
+ :- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+ +- Exchange(distribution=[broadcast])
+ +- Calc(select=[j])
+ +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testMultiNotExistsWithUncorrelated">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l, r WHERE l.a = r.d AND NOT EXISTS (SELECT * FROM t t1 WHERE t1.i > 50) AND b >= 1 AND NOT EXISTS (SELECT * FROM t t2 WHERE t2.j < 100)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], e=[$4], f=[$5])
++- LogicalFilter(condition=[AND(=($0, $3), NOT(EXISTS({
+LogicalFilter(condition=[>($0, 50)])
+ LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+})), >=($1, 1), NOT(EXISTS({
+LogicalFilter(condition=[<($1, 100)])
+ LogicalTableScan(table=[[t, source: [TestTableSource(i, j, k)]]])
+})))])
+ +- LogicalJoin(condition=[true], joinType=[inner])
+ :- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+ +- LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+NestedLoopJoin(joinType=[LeftAntiJoin], where=[$f0], select=[a, b, c, d, e, f], build=[right], singleRowJoin=[true])
+:- NestedLoopJoin(joinType=[LeftAntiJoin], where=[$f0], select=[a, b, c, d, e, f], build=[right], singleRowJoin=[true])
+: :- HashJoin(joinType=[InnerJoin], where=[=(a, d)], select=[a, b, c, d, e, f], isBroadcast=[true], build=[left])
+: : :- Exchange(distribution=[broadcast])
+: : : +- Calc(select=[a, b, c], where=[>=(b, 1)])
+: : : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+: : +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
+: +- Exchange(distribution=[broadcast])
+: +- Calc(select=[IS NOT NULL(m) AS $f0])
+: +- HashAggregate(isMerge=[true], select=[Final_MIN(min$0) AS m])
+: +- Exchange(distribution=[single])
+: +- LocalHashAggregate(select=[Partial_MIN(i) AS min$0])
+: +- Calc(select=[true AS i], where=[>(i, 50)])
+: +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1])
++- Exchange(distribution=[broadcast])
+ +- Calc(select=[IS NOT NULL(m) AS $f0])
+ +- HashAggregate(isMerge=[true], select=[Final_MIN(min$0) AS m])
+ +- Exchange(distribution=[single])
+ +- LocalHashAggregate(select=[Partial_MIN(i) AS min$0])
+ +- Calc(select=[true AS i], where=[<(j, 100)])
+ +- Reused(reference_id=[1])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testNotExistsWithCorrelated_SimpleCondition">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM l WHERE NOT EXISTS (SELECT * FROM r WHERE a = d)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2])
++- LogicalFilter(condition=[NOT(EXISTS({
+LogicalFilter(condition=[=($cor0.a, $0)])
+ LogicalTableScan(table=[[r, source: [TestTableSource(d, e, f)]]])
+}))], variablesSet=[[$cor0]])
+ +- LogicalTableScan(table=[[l, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+HashJoin(joinType=[LeftAntiJoin], where=[=(a, d)], select=[a, b, c], isBroadcast=[true], build=[right])
... 24429 lines suppressed ...