You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by ja...@apache.org on 2022/10/24 09:11:50 UTC
[doris] branch master updated: [improve](Nereids): ReorderJoin eliminate this recursion (#13505)
This is an automated email from the ASF dual-hosted git repository.
jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 409bd76999 [improve](Nereids): ReorderJoin eliminate this recursion (#13505)
409bd76999 is described below
commit 409bd76999236e9e624e6f2b1f3428191e86ea57
Author: jakevin <ja...@gmail.com>
AuthorDate: Mon Oct 24 17:11:43 2022 +0800
[improve](Nereids): ReorderJoin eliminate this recursion (#13505)
---
.../nereids/rules/rewrite/logical/ReorderJoin.java | 82 ++++++++++++----------
.../rules/rewrite/logical/ReorderJoinTest.java | 33 ++++++++-
.../doris/nereids/sqltest/MultiJoinTest.java | 39 +++++++++-
.../org/apache/doris/nereids/util/PlanChecker.java | 1 +
4 files changed, 112 insertions(+), 43 deletions(-)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java
index 1cbdc370e2..c0c8622348 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java
@@ -38,6 +38,7 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.Lists;
+import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@@ -79,8 +80,7 @@ public class ReorderJoin extends OneRewriteRuleFactory {
/**
* Recursively convert to
- * {@link LogicalJoin} or
- * {@link LogicalFilter}--{@link LogicalJoin}
+ * {@link LogicalJoin} or {@link LogicalFilter}--{@link LogicalJoin}
* --> {@link MultiJoin}
*/
public Plan joinToMultiJoin(Plan plan) {
@@ -182,20 +182,20 @@ public class ReorderJoin extends OneRewriteRuleFactory {
* <li> A JOIN B RIGHT JOIN (C JOIN D) --> MJ(A, B, MJ([ROJ]C, D))
* </ul>
* </p>
- * <p>
* Graphic presentation:
+ * <pre>
* A JOIN B JOIN C LEFT JOIN D JOIN F
* left left│
* A B C D F ──► A B C │ D F ──► MJ(LOJ A,B,C,MJ(DF)
- * <p>
+ *
* A JOIN B RIGHT JOIN C JOIN D JOIN F
* right │right
* A B C D F ──► A B │ C D F ──► MJ(A,B,MJ(ROJ C,D,F)
- * <p>
+ *
* (A JOIN B JOIN C) FULL JOIN (D JOIN F)
* full │
* A B C D F ──► A B C │ D F ──► MJ(FOJ MJ(A,B,C) MJ(D,F))
- * </p>
+ * </pre>
*/
public Plan multiJoinToJoin(MultiJoin multiJoin) {
if (multiJoin.arity() == 1) {
@@ -272,24 +272,22 @@ public class ReorderJoin extends OneRewriteRuleFactory {
}
// following this multiJoin just contain INNER/CROSS.
- List<Expression> joinFilter = multiJoinHandleChildren.getJoinFilter();
+ Set<Expression> joinFilter = new HashSet<>(multiJoinHandleChildren.getJoinFilter());
Plan left = multiJoinHandleChildren.child(0);
- List<Plan> candidates = multiJoinHandleChildren.children().subList(1, multiJoinHandleChildren.arity());
-
- LogicalJoin<? extends Plan, ? extends Plan> join = findInnerJoin(left, candidates, joinFilter);
- List<Plan> newInputs = Lists.newArrayList();
- newInputs.add(join);
- newInputs.addAll(candidates.stream().filter(plan -> !join.right().equals(plan)).collect(Collectors.toList()));
-
- joinFilter.removeAll(join.getHashJoinConjuncts());
- joinFilter.removeAll(join.getOtherJoinConjuncts());
- // TODO(wj): eliminate this recursion.
- return multiJoinToJoin(new MultiJoin(
- newInputs,
- joinFilter,
- JoinType.INNER_JOIN,
- ExpressionUtils.EMPTY_CONDITION));
+ Set<Integer> usedPlansIndex = new HashSet<>();
+ usedPlansIndex.add(0);
+
+ while (usedPlansIndex.size() != multiJoinHandleChildren.children().size()) {
+ LogicalJoin<? extends Plan, ? extends Plan> join = findInnerJoin(left, multiJoinHandleChildren.children(),
+ joinFilter, usedPlansIndex);
+ join.getHashJoinConjuncts().forEach(joinFilter::remove);
+ join.getOtherJoinConjuncts().forEach(joinFilter::remove);
+
+ left = join;
+ }
+
+ return PlanUtils.filterOrSelf(new ArrayList<>(joinFilter), left);
}
/**
@@ -319,9 +317,14 @@ public class ReorderJoin extends OneRewriteRuleFactory {
* @return InnerJoin or CrossJoin{left, last of [candidates]}
*/
private LogicalJoin<? extends Plan, ? extends Plan> findInnerJoin(Plan left, List<Plan> candidates,
- List<Expression> joinFilter) {
+ Set<Expression> joinFilter, Set<Integer> usedPlansIndex) {
+ List<Expression> otherJoinConditions = Lists.newArrayList();
Set<Slot> leftOutputSet = left.getOutputSet();
for (int i = 0; i < candidates.size(); i++) {
+ if (usedPlansIndex.contains(i)) {
+ continue;
+ }
+
Plan candidate = candidates.get(i);
Set<Slot> rightOutputSet = candidate.getOutputSet();
@@ -330,34 +333,35 @@ public class ReorderJoin extends OneRewriteRuleFactory {
List<Expression> currentJoinFilter = joinFilter.stream()
.filter(expr -> {
Set<Slot> exprInputSlots = expr.getInputSlots();
- Preconditions.checkState(exprInputSlots.size() > 1,
- "Predicate like table.col > 1 must have pushdown.");
- if (leftOutputSet.containsAll(exprInputSlots)) {
- return false;
- }
- if (rightOutputSet.containsAll(exprInputSlots)) {
- return false;
- }
-
- return joinOutput.containsAll(exprInputSlots);
+ return !leftOutputSet.containsAll(exprInputSlots)
+ && !rightOutputSet.containsAll(exprInputSlots)
+ && joinOutput.containsAll(exprInputSlots);
}).collect(Collectors.toList());
Pair<List<Expression>, List<Expression>> pair = JoinUtils.extractExpressionForHashTable(
left.getOutput(), candidate.getOutput(), currentJoinFilter);
List<Expression> hashJoinConditions = pair.first;
- List<Expression> otherJoinConditions = pair.second;
+ otherJoinConditions = pair.second;
if (!hashJoinConditions.isEmpty()) {
+ usedPlansIndex.add(i);
return new LogicalJoin<>(JoinType.INNER_JOIN,
hashJoinConditions, otherJoinConditions,
left, candidate);
}
-
- if (i == candidates.size() - 1) {
- return new LogicalJoin<>(JoinType.CROSS_JOIN,
- hashJoinConditions, otherJoinConditions,
- left, candidate);
+ }
+ // All { left -> one in [candidates] } is CrossJoin
+ // Generate a CrossJoin
+ for (int j = candidates.size() - 1; j >= 0; j--) {
+ if (usedPlansIndex.contains(j)) {
+ continue;
}
+ usedPlansIndex.add(j);
+ return new LogicalJoin<>(JoinType.CROSS_JOIN,
+ ExpressionUtils.EMPTY_CONDITION,
+ otherJoinConditions,
+ left, candidates.get(j));
}
+
throw new RuntimeException("findInnerJoin: can't reach here");
}
}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoinTest.java
index ffb7e16510..da386f8911 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoinTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoinTest.java
@@ -130,17 +130,44 @@ class ReorderJoinTest implements PatternMatchSupported {
check(plans);
}
- public void check(List<LogicalPlan> plans) {
+ @Test
+ public void testCrossJoin() {
+ ImmutableList<LogicalPlan> plans = ImmutableList.of(
+ new LogicalPlanBuilder(scan1)
+ .hashJoinEmptyOn(scan2, JoinType.CROSS_JOIN)
+ .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
+ .filter(new EqualTo(scan1.getOutput().get(0), scan3.getOutput().get(0)))
+ .build(),
+ new LogicalPlanBuilder(scan1)
+ .hashJoinEmptyOn(scan2, JoinType.CROSS_JOIN)
+ .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
+ .filter(new EqualTo(scan1.getOutput().get(0), scan2.getOutput().get(0)))
+ .build()
+ );
+
for (LogicalPlan plan : plans) {
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyBottomUp(new ReorderJoin())
+ .matchesFromRoot(
+ logicalJoin(
+ logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()),
+ leafPlan()
+ ).when(join -> join.getJoinType().isCrossJoin())
+ );
+ }
+ }
+
+ public void check(List<LogicalPlan> plans) {
+ for (LogicalPlan plan : plans) {
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .rewrite()
+ .printlnTree()
.matchesFromRoot(
logicalJoin(
logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()),
leafPlan()
).whenNot(join -> join.getJoinType().isCrossJoin())
- )
- .printlnTree();
+ );
}
}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/MultiJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/MultiJoinTest.java
index 230c9cc245..5beb12445c 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/MultiJoinTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/MultiJoinTest.java
@@ -21,6 +21,7 @@ import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin;
import org.apache.doris.nereids.util.PlanChecker;
import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import java.util.List;
@@ -29,8 +30,9 @@ public class MultiJoinTest extends SqlTestBase {
@Test
void testMultiJoinEliminateCross() {
List<String> sqls = ImmutableList.<String>builder()
- .add("SELECT * FROM T1, T2 LEFT JOIN T3 ON T2.id = T3.id WHERE T1.id = T2.id")
.add("SELECT * FROM T2 LEFT JOIN T3 ON T2.id = T3.id, T1 WHERE T1.id = T2.id")
+ .add("SELECT * FROM T2 LEFT JOIN T3 ON T2.id = T3.id, T1 WHERE T1.id = T2.id AND T1.score > 0")
+ .add("SELECT * FROM T2 LEFT JOIN T3 ON T2.id = T3.id, T1 WHERE T1.id = T2.id AND T1.score > 0 AND T1.id + T2.id + T3.id > 0")
.build();
for (String sql : sqls) {
@@ -47,6 +49,41 @@ public class MultiJoinTest extends SqlTestBase {
}
}
+ @Test
+ @Disabled
+ // TODO: MultiJoin And EliminateOuter
+ void testEliminateBelowOuter() {
+ String sql = "SELECT * FROM T1, T2 LEFT JOIN T3 ON T2.id = T3.id WHERE T1.id = T2.id";
+ PlanChecker.from(connectContext)
+ .analyze(sql)
+ .applyBottomUp(new ReorderJoin())
+ .printlnTree();
+ }
+
+ @Test
+ void testPushdownAndEliminateOuter() {
+ String sql = "SELECT * FROM T1 LEFT JOIN T2 ON T1.id = T2.id WHERE T2.score > 0";
+ PlanChecker.from(connectContext)
+ .analyze(sql)
+ .rewrite()
+ .printlnTree()
+ .matches(
+ logicalJoin().when(join -> join.getJoinType().isInnerJoin())
+ );
+
+ String sql1 = "SELECT * FROM T1, T2 LEFT JOIN T3 ON T2.id = T3.id WHERE T1.id = T2.id AND T3.score > 0";
+ PlanChecker.from(connectContext)
+ .analyze(sql1)
+ .rewrite()
+ .printlnTree()
+ .matches(
+ logicalJoin(
+ logicalJoin().when(join -> join.getJoinType().isInnerJoin()),
+ any()
+ ).when(join -> join.getJoinType().isInnerJoin())
+ );
+ }
+
@Test
void testMultiJoinExistCross() {
List<String> sqls = ImmutableList.<String>builder()
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java
index ee364c7854..3ef68ead77 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java
@@ -363,6 +363,7 @@ public class PlanChecker {
public PlanChecker printlnTree() {
System.out.println(cascadesContext.getMemo().copyOut().treeString());
+ System.out.println("-----------------------------");
return this;
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org