You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by ja...@apache.org on 2022/10/24 09:11:50 UTC

[doris] branch master updated: [improve](Nereids): ReorderJoin eliminate this recursion (#13505)

This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 409bd76999 [improve](Nereids): ReorderJoin eliminate this recursion (#13505)
409bd76999 is described below

commit 409bd76999236e9e624e6f2b1f3428191e86ea57
Author: jakevin <ja...@gmail.com>
AuthorDate: Mon Oct 24 17:11:43 2022 +0800

    [improve](Nereids): ReorderJoin eliminate this recursion (#13505)
---
 .../nereids/rules/rewrite/logical/ReorderJoin.java | 82 ++++++++++++----------
 .../rules/rewrite/logical/ReorderJoinTest.java     | 33 ++++++++-
 .../doris/nereids/sqltest/MultiJoinTest.java       | 39 +++++++++-
 .../org/apache/doris/nereids/util/PlanChecker.java |  1 +
 4 files changed, 112 insertions(+), 43 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java
index 1cbdc370e2..c0c8622348 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java
@@ -38,6 +38,7 @@ import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableList.Builder;
 import com.google.common.collect.Lists;
 
+import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
@@ -79,8 +80,7 @@ public class ReorderJoin extends OneRewriteRuleFactory {
 
     /**
      * Recursively convert to
-     * {@link LogicalJoin} or
-     * {@link LogicalFilter}--{@link LogicalJoin}
+     * {@link LogicalJoin} or {@link LogicalFilter}--{@link LogicalJoin}
      * --> {@link MultiJoin}
      */
     public Plan joinToMultiJoin(Plan plan) {
@@ -182,20 +182,20 @@ public class ReorderJoin extends OneRewriteRuleFactory {
      * <li> A JOIN B RIGHT JOIN (C JOIN D) --> MJ(A, B, MJ([ROJ]C, D))
      * </ul>
      * </p>
-     * <p>
      * Graphic presentation:
+     * <pre>
      * A JOIN B JOIN C LEFT JOIN D JOIN F
      *      left                  left│
      * A  B  C  D  F   ──►   A  B  C  │ D  F   ──►  MJ(LOJ A,B,C,MJ(DF)
-     * <p>
+     *
      * A JOIN B RIGHT JOIN C JOIN D JOIN F
      *     right                  │right
      * A  B  C  D  F   ──►   A  B │  C  D  F   ──►  MJ(A,B,MJ(ROJ C,D,F)
-     * <p>
+     *
      * (A JOIN B JOIN C) FULL JOIN (D JOIN F)
      *       full                    │
      * A  B  C  D  F   ──►   A  B  C │ D  F    ──►  MJ(FOJ MJ(A,B,C) MJ(D,F))
-     * </p>
+     * </pre>
      */
     public Plan multiJoinToJoin(MultiJoin multiJoin) {
         if (multiJoin.arity() == 1) {
@@ -272,24 +272,22 @@ public class ReorderJoin extends OneRewriteRuleFactory {
         }
 
         // following this multiJoin just contain INNER/CROSS.
-        List<Expression> joinFilter = multiJoinHandleChildren.getJoinFilter();
+        Set<Expression> joinFilter = new HashSet<>(multiJoinHandleChildren.getJoinFilter());
 
         Plan left = multiJoinHandleChildren.child(0);
-        List<Plan> candidates = multiJoinHandleChildren.children().subList(1, multiJoinHandleChildren.arity());
-
-        LogicalJoin<? extends Plan, ? extends Plan> join = findInnerJoin(left, candidates, joinFilter);
-        List<Plan> newInputs = Lists.newArrayList();
-        newInputs.add(join);
-        newInputs.addAll(candidates.stream().filter(plan -> !join.right().equals(plan)).collect(Collectors.toList()));
-
-        joinFilter.removeAll(join.getHashJoinConjuncts());
-        joinFilter.removeAll(join.getOtherJoinConjuncts());
-        // TODO(wj): eliminate this recursion.
-        return multiJoinToJoin(new MultiJoin(
-                newInputs,
-                joinFilter,
-                JoinType.INNER_JOIN,
-                ExpressionUtils.EMPTY_CONDITION));
+        Set<Integer> usedPlansIndex = new HashSet<>();
+        usedPlansIndex.add(0);
+
+        while (usedPlansIndex.size() != multiJoinHandleChildren.children().size()) {
+            LogicalJoin<? extends Plan, ? extends Plan> join = findInnerJoin(left, multiJoinHandleChildren.children(),
+                    joinFilter, usedPlansIndex);
+            join.getHashJoinConjuncts().forEach(joinFilter::remove);
+            join.getOtherJoinConjuncts().forEach(joinFilter::remove);
+
+            left = join;
+        }
+
+        return PlanUtils.filterOrSelf(new ArrayList<>(joinFilter), left);
     }
 
     /**
@@ -319,9 +317,14 @@ public class ReorderJoin extends OneRewriteRuleFactory {
      * @return InnerJoin or CrossJoin{left, last of [candidates]}
      */
     private LogicalJoin<? extends Plan, ? extends Plan> findInnerJoin(Plan left, List<Plan> candidates,
-            List<Expression> joinFilter) {
+            Set<Expression> joinFilter, Set<Integer> usedPlansIndex) {
+        List<Expression> otherJoinConditions = Lists.newArrayList();
         Set<Slot> leftOutputSet = left.getOutputSet();
         for (int i = 0; i < candidates.size(); i++) {
+            if (usedPlansIndex.contains(i)) {
+                continue;
+            }
+
             Plan candidate = candidates.get(i);
             Set<Slot> rightOutputSet = candidate.getOutputSet();
 
@@ -330,34 +333,35 @@ public class ReorderJoin extends OneRewriteRuleFactory {
             List<Expression> currentJoinFilter = joinFilter.stream()
                     .filter(expr -> {
                         Set<Slot> exprInputSlots = expr.getInputSlots();
-                        Preconditions.checkState(exprInputSlots.size() > 1,
-                                "Predicate like table.col > 1 must have pushdown.");
-                        if (leftOutputSet.containsAll(exprInputSlots)) {
-                            return false;
-                        }
-                        if (rightOutputSet.containsAll(exprInputSlots)) {
-                            return false;
-                        }
-
-                        return joinOutput.containsAll(exprInputSlots);
+                        return !leftOutputSet.containsAll(exprInputSlots)
+                                && !rightOutputSet.containsAll(exprInputSlots)
+                                && joinOutput.containsAll(exprInputSlots);
                     }).collect(Collectors.toList());
 
             Pair<List<Expression>, List<Expression>> pair = JoinUtils.extractExpressionForHashTable(
                     left.getOutput(), candidate.getOutput(), currentJoinFilter);
             List<Expression> hashJoinConditions = pair.first;
-            List<Expression> otherJoinConditions = pair.second;
+            otherJoinConditions = pair.second;
             if (!hashJoinConditions.isEmpty()) {
+                usedPlansIndex.add(i);
                 return new LogicalJoin<>(JoinType.INNER_JOIN,
                         hashJoinConditions, otherJoinConditions,
                         left, candidate);
             }
-
-            if (i == candidates.size() - 1) {
-                return new LogicalJoin<>(JoinType.CROSS_JOIN,
-                        hashJoinConditions, otherJoinConditions,
-                        left, candidate);
+        }
+        // No unused candidate produced a hash-join condition with left, so every
+        // { left -> one of [candidates] } pairing is a CrossJoin; generate one.
+        for (int j = candidates.size() - 1; j >= 0; j--) {
+            if (usedPlansIndex.contains(j)) {
+                continue;
             }
+            usedPlansIndex.add(j);
+            return new LogicalJoin<>(JoinType.CROSS_JOIN,
+                    ExpressionUtils.EMPTY_CONDITION,
+                    otherJoinConditions,
+                    left, candidates.get(j));
         }
+
         throw new RuntimeException("findInnerJoin: can't reach here");
     }
 }
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoinTest.java
index ffb7e16510..da386f8911 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoinTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoinTest.java
@@ -130,17 +130,44 @@ class ReorderJoinTest implements PatternMatchSupported {
         check(plans);
     }
 
-    public void check(List<LogicalPlan> plans) {
+    @Test
+    public void testCrossJoin() {
+        ImmutableList<LogicalPlan> plans = ImmutableList.of(
+                new LogicalPlanBuilder(scan1)
+                        .hashJoinEmptyOn(scan2, JoinType.CROSS_JOIN)
+                        .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
+                        .filter(new EqualTo(scan1.getOutput().get(0), scan3.getOutput().get(0)))
+                        .build(),
+                new LogicalPlanBuilder(scan1)
+                        .hashJoinEmptyOn(scan2, JoinType.CROSS_JOIN)
+                        .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
+                        .filter(new EqualTo(scan1.getOutput().get(0), scan2.getOutput().get(0)))
+                        .build()
+        );
+
         for (LogicalPlan plan : plans) {
             PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
                     .applyBottomUp(new ReorderJoin())
+                    .matchesFromRoot(
+                            logicalJoin(
+                                    logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()),
+                                    leafPlan()
+                            ).when(join -> join.getJoinType().isCrossJoin())
+                    );
+        }
+    }
+
+    public void check(List<LogicalPlan> plans) {
+        for (LogicalPlan plan : plans) {
+            PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                    .rewrite()
+                    .printlnTree()
                     .matchesFromRoot(
                             logicalJoin(
                                     logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()),
                                     leafPlan()
                             ).whenNot(join -> join.getJoinType().isCrossJoin())
-                    )
-                    .printlnTree();
+                    );
         }
     }
 
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/MultiJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/MultiJoinTest.java
index 230c9cc245..5beb12445c 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/MultiJoinTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/MultiJoinTest.java
@@ -21,6 +21,7 @@ import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin;
 import org.apache.doris.nereids.util.PlanChecker;
 
 import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
 
 import java.util.List;
@@ -29,8 +30,9 @@ public class MultiJoinTest extends SqlTestBase {
     @Test
     void testMultiJoinEliminateCross() {
         List<String> sqls = ImmutableList.<String>builder()
-                .add("SELECT * FROM T1, T2 LEFT JOIN T3 ON T2.id = T3.id WHERE T1.id = T2.id")
                 .add("SELECT * FROM T2 LEFT JOIN T3 ON T2.id = T3.id, T1 WHERE T1.id = T2.id")
+                .add("SELECT * FROM T2 LEFT JOIN T3 ON T2.id = T3.id, T1 WHERE T1.id = T2.id AND T1.score > 0")
+                .add("SELECT * FROM T2 LEFT JOIN T3 ON T2.id = T3.id, T1 WHERE T1.id = T2.id AND T1.score > 0 AND T1.id + T2.id + T3.id > 0")
                 .build();
 
         for (String sql : sqls) {
@@ -47,6 +49,41 @@ public class MultiJoinTest extends SqlTestBase {
         }
     }
 
+    @Test
+    @Disabled
+    // TODO: MultiJoin And EliminateOuter
+    void testEliminateBelowOuter() {
+        String sql = "SELECT * FROM T1, T2 LEFT JOIN T3 ON T2.id = T3.id WHERE T1.id = T2.id";
+        PlanChecker.from(connectContext)
+                .analyze(sql)
+                .applyBottomUp(new ReorderJoin())
+                .printlnTree();
+    }
+
+    @Test
+    void testPushdownAndEliminateOuter() {
+        String sql = "SELECT * FROM T1 LEFT JOIN T2 ON T1.id = T2.id WHERE T2.score > 0";
+        PlanChecker.from(connectContext)
+                .analyze(sql)
+                .rewrite()
+                .printlnTree()
+                .matches(
+                        logicalJoin().when(join -> join.getJoinType().isInnerJoin())
+                );
+
+        String sql1 = "SELECT * FROM T1, T2 LEFT JOIN T3 ON T2.id = T3.id WHERE T1.id = T2.id AND T3.score > 0";
+        PlanChecker.from(connectContext)
+                .analyze(sql1)
+                .rewrite()
+                .printlnTree()
+                .matches(
+                        logicalJoin(
+                                logicalJoin().when(join -> join.getJoinType().isInnerJoin()),
+                                any()
+                        ).when(join -> join.getJoinType().isInnerJoin())
+                );
+    }
+
     @Test
     void testMultiJoinExistCross() {
         List<String> sqls = ImmutableList.<String>builder()
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java
index ee364c7854..3ef68ead77 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java
@@ -363,6 +363,7 @@ public class PlanChecker {
 
     public PlanChecker printlnTree() {
         System.out.println(cascadesContext.getMemo().copyOut().treeString());
+        System.out.println("-----------------------------");
         return this;
     }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org