You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by hu...@apache.org on 2022/07/27 11:26:08 UTC

[doris] branch master updated: [feature](Nereids): add MultiJoin. (#11254)

This is an automated email from the ASF dual-hosted git repository.

huajianlan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new a2f39278e2 [feature](Nereids): add MultiJoin. (#11254)
a2f39278e2 is described below

commit a2f39278e21cdf25816aa41938d126b62c00664f
Author: jakevin <ja...@gmail.com>
AuthorDate: Wed Jul 27 19:26:02 2022 +0800

    [feature](Nereids): add MultiJoin. (#11254)
    
    Add MultiJoin.
    
    In addition, when (joinInputs.size() >= 3 && !conjuncts.isEmpty()), the conjuncts can still contain an ON predicate.
    
    Like:
    ```
    A join B on A.id = B.id where A.sid = B.sid
    ```
---
 .../nereids/rules/rewrite/logical/MultiJoin.java   | 196 +++++++++++++++++++++
 .../nereids/rules/rewrite/logical/ReorderJoin.java | 168 +-----------------
 2 files changed, 199 insertions(+), 165 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java
new file mode 100644
index 0000000000..b85b80cb8a
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java
@@ -0,0 +1,196 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.base.Preconditions;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * A MultiJoin represents a join of N inputs (N-ary join), collected by visiting a plan tree.
+ * A regular join represents a strictly binary input (binary join).
+ */
+public class MultiJoin extends PlanVisitor<Void, Void> {
+    /*
+     *        topJoin
+     *        /     \            MultiJoin
+     *   bottomJoin  C  -->     /    |    \
+     *     /    \              A     B     C
+     *    A      B
+     */
+    public final List<Plan> joinInputs = new ArrayList<>(); // non-LogicalJoin children of visited joins
+    public final List<Expression> conjuncts = new ArrayList<>(); // split from filter predicates and join conditions
+
+    public Plan reorderJoinsAccordingToConditions() {
+        Preconditions.checkArgument(joinInputs.size() >= 2); // IllegalArgumentException if fewer than 2 inputs
+        return reorderJoinsAccordingToConditions(joinInputs, conjuncts);
+    }
+
+    /**
+     * Reorder join orders according to join conditions to eliminate cross join.
+     * <p/>
+     * Let's say we have input join tables: [t1, t2, t3] and
+     * conjunctive predicates: [t1.id=t3.id, t2.id=t3.id]
+     * The input join for t1 and t2 is cross join.
+     * <p/>
+     * The algorithm splits join inputs into two groups: `left input` t1 and `candidate right input` [t2, t3].
+     * Try to find an inner join from t1 and candidate right inputs [t2, t3], if any combination
+     * of [Join(t1, t2), Join(t1, t3)] could be optimized to inner join according to the join conditions.
+     * <p/>
+     * As a result, Join(t1, t3) is an inner join.
+     * Then the logic is applied to the rest of [Join(t1, t3), t2] recursively.
+     */
+    private Plan reorderJoinsAccordingToConditions(List<Plan> joinInputs, List<Expression> conjuncts) {
+        if (joinInputs.size() == 2) { // base case: build the last join, leftover conjuncts go into a filter
+            Set<Slot> joinOutput = getJoinOutput(joinInputs.get(0), joinInputs.get(1));
+            Map<Boolean, List<Expression>> split = splitConjuncts(conjuncts, joinOutput);
+            List<Expression> joinConditions = split.get(true);
+            List<Expression> nonJoinConditions = split.get(false);
+
+            Optional<Expression> cond;
+            if (joinConditions.isEmpty()) {
+                cond = Optional.empty();
+            } else {
+                cond = Optional.of(ExpressionUtils.and(joinConditions));
+            }
+
+            LogicalJoin join = new LogicalJoin(JoinType.INNER_JOIN, cond, joinInputs.get(0), joinInputs.get(1));
+            if (nonJoinConditions.isEmpty()) {
+                return join;
+            } else {
+                return new LogicalFilter(ExpressionUtils.and(nonJoinConditions), join);
+            }
+        }
+        // input size >= 3: the first input is the left side, all remaining inputs are right-side candidates.
+        Plan left = joinInputs.get(0);
+        List<Plan> candidate = joinInputs.subList(1, joinInputs.size());
+
+        List<Slot> leftOutput = left.getOutput();
+        Optional<Plan> rightOpt = candidate.stream().filter(right -> { // first candidate tied to left by a conjunct
+            List<Slot> rightOutput = right.getOutput();
+
+            Set<Slot> joinOutput = getJoinOutput(left, right);
+            Optional<Expression> joinCond = conjuncts.stream()
+                    .filter(expr -> {
+                        Set<Slot> exprInputSlots = SlotExtractor.extractSlot(expr);
+                        if (exprInputSlots.isEmpty()) {
+                            return false;
+                        }
+
+                        if (new HashSet<>(leftOutput).containsAll(exprInputSlots)) { // uses left slots only
+                            return false;
+                        }
+
+                        if (new HashSet<>(rightOutput).containsAll(exprInputSlots)) { // uses right slots only
+                            return false;
+                        }
+
+                        return joinOutput.containsAll(exprInputSlots);
+                    }).findFirst();
+            return joinCond.isPresent();
+        }).findFirst();
+
+        Plan right = rightOpt.orElseGet(() -> candidate.get(1)); // NOTE(review): fallback is index 1, not 0 - confirm
+        Set<Slot> joinOutput = getJoinOutput(left, right);
+        Map<Boolean, List<Expression>> split = splitConjuncts(conjuncts, joinOutput);
+        List<Expression> joinConditions = split.get(true);
+        List<Expression> nonJoinConditions = split.get(false);
+
+        Optional<Expression> cond;
+        if (joinConditions.isEmpty()) {
+            cond = Optional.empty();
+        } else {
+            cond = Optional.of(ExpressionUtils.and(joinConditions));
+        }
+
+        LogicalJoin join = new LogicalJoin(JoinType.INNER_JOIN, cond, left, right);
+
+        List<Plan> newInputs = new ArrayList<>();
+        newInputs.add(join);
+        newInputs.addAll(candidate.stream().filter(plan -> !right.equals(plan)).collect(Collectors.toList()));
+        return reorderJoinsAccordingToConditions(newInputs, nonJoinConditions); // recurse with the unused conjuncts
+    }
+
+    private Map<Boolean, List<Expression>> splitConjuncts(List<Expression> conjuncts, Set<Slot> slots) {
+        return conjuncts.stream().collect(Collectors.partitioningBy( // true: EqualTo fully covered by slots
+                // TODO: support non equal to conditions.
+                expr -> expr instanceof EqualTo && slots.containsAll(SlotExtractor.extractSlot(expr))));
+    }
+
+    private Set<Slot> getJoinOutput(Plan left, Plan right) {
+        HashSet<Slot> joinOutput = new HashSet<>(); // union of the output slots of both sides
+        joinOutput.addAll(left.getOutput());
+        joinOutput.addAll(right.getOutput());
+        return joinOutput;
+    }
+
+    @Override
+    public Void visit(Plan plan, Void context) {
+        for (Plan child : plan.children()) { // default: only descend into every child
+            child.accept(this, context);
+        }
+        return null;
+    }
+
+    @Override
+    public Void visitLogicalFilter(LogicalFilter<Plan> filter, Void context) {
+        Plan child = filter.child();
+        if (child instanceof LogicalJoin) { // predicates are collected only when the filter sits directly on a join
+            conjuncts.addAll(ExpressionUtils.extractConjunctive(filter.getPredicates()));
+        }
+
+        child.accept(this, context);
+        return null;
+    }
+
+    @Override
+    public Void visitLogicalJoin(LogicalJoin<Plan, Plan> join, Void context) {
+        if (join.getJoinType() != JoinType.CROSS_JOIN && join.getJoinType() != JoinType.INNER_JOIN) {
+            return null; // NOTE(review): parent also skips LogicalJoin children, so this subtree looks dropped
+        }
+
+        join.left().accept(this, context);
+        join.right().accept(this, context);
+
+        join.getCondition().ifPresent(cond -> conjuncts.addAll(ExpressionUtils.extractConjunctive(cond)));
+        if (!(join.left() instanceof LogicalJoin)) { // nested joins were expanded by the accept calls above
+            joinInputs.add(join.left());
+        }
+        if (!(join.right() instanceof LogicalJoin)) {
+            joinInputs.add(join.right());
+        }
+        return null;
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java
index 7fafe2ec90..c79e16a5de 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java
@@ -20,24 +20,9 @@ package org.apache.doris.nereids.rules.rewrite.logical;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
-import org.apache.doris.nereids.trees.expressions.EqualTo;
-import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.Slot;
-import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor;
-import org.apache.doris.nereids.trees.plans.JoinType;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
 import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
-import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
-import org.apache.doris.nereids.util.ExpressionUtils;
-
-import java.util.ArrayList;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.Set;
-import java.util.stream.Collectors;
 
 /**
  * Try to eliminate cross join via finding join conditions in filters and change the join orders.
@@ -64,157 +49,10 @@ public class ReorderJoin extends OneRewriteRuleFactory {
                     .isEnableNereidsReorderToEliminateCrossJoin()) {
                 return filter;
             }
-            PlanCollector collector = new PlanCollector();
-            filter.accept(collector, null);
-            List<Plan> joinInputs = collector.joinInputs;
-            List<Expression> conjuncts = collector.conjuncts;
+            MultiJoin multiJoin = new MultiJoin();
+            filter.accept(multiJoin, null);
 
-            if (joinInputs.size() >= 3 && !conjuncts.isEmpty()) {
-                return reorderJoinsAccordingToConditions(joinInputs, conjuncts);
-            } else {
-                return filter;
-            }
+            return multiJoin.reorderJoinsAccordingToConditions();
         }).toRule(RuleType.REORDER_JOIN);
     }
-
-    /**
-     * Reorder join orders according to join conditions to eliminate cross join.
-     * <p/>
-     * Let's say we have input join tables: [t1, t2, t3] and
-     * conjunctive predicates: [t1.id=t3.id, t2.id=t3.id]
-     * The input join for t1 and t2 is cross join.
-     * <p/>
-     * The algorithm split join inputs into two groups: `left input` t1 and `candidate right input` [t2, t3].
-     * Try to find an inner join from t1 and candidate right inputs [t2, t3], if any combination
-     * of [Join(t1, t2), Join(t1, t3)] could be optimized to inner join according to the join conditions.
-     * <p/>
-     * As a result, Join(t1, t3) is an inner join.
-     * Then the logic is applied to the rest of [Join(t1, t3), t2] recursively.
-     */
-    private Plan reorderJoinsAccordingToConditions(List<Plan> joinInputs, List<Expression> conjuncts) {
-        if (joinInputs.size() == 2) {
-            Set<Slot> joinOutput = getJoinOutput(joinInputs.get(0), joinInputs.get(1));
-            Map<Boolean, List<Expression>> split = splitConjuncts(conjuncts, joinOutput);
-            List<Expression> joinConditions = split.get(true);
-            List<Expression> nonJoinConditions = split.get(false);
-
-            Optional<Expression> cond;
-            if (joinConditions.isEmpty()) {
-                cond = Optional.empty();
-            } else {
-                cond = Optional.of(ExpressionUtils.and(joinConditions));
-            }
-
-            LogicalJoin join = new LogicalJoin(JoinType.INNER_JOIN, cond, joinInputs.get(0), joinInputs.get(1));
-            if (nonJoinConditions.isEmpty()) {
-                return join;
-            } else {
-                return new LogicalFilter(ExpressionUtils.and(nonJoinConditions), join);
-            }
-        } else {
-            Plan left = joinInputs.get(0);
-            List<Plan> candidate = joinInputs.subList(1, joinInputs.size());
-
-            List<Slot> leftOutput = left.getOutput();
-            Optional<Plan> rightOpt = candidate.stream().filter(right -> {
-                List<Slot> rightOutput = right.getOutput();
-
-                Set<Slot> joinOutput = getJoinOutput(left, right);
-                Optional<Expression> joinCond = conjuncts.stream()
-                        .filter(expr -> {
-                            Set<Slot> exprInputSlots = SlotExtractor.extractSlot(expr);
-                            if (exprInputSlots.isEmpty()) {
-                                return false;
-                            }
-
-                            if (new HashSet<>(leftOutput).containsAll(exprInputSlots)) {
-                                return false;
-                            }
-
-                            if (new HashSet<>(rightOutput).containsAll(exprInputSlots)) {
-                                return false;
-                            }
-
-                            return joinOutput.containsAll(exprInputSlots);
-                        }).findFirst();
-                return joinCond.isPresent();
-            }).findFirst();
-
-            Plan right = rightOpt.orElseGet(() -> candidate.get(1));
-            Set<Slot> joinOutput = getJoinOutput(left, right);
-            Map<Boolean, List<Expression>> split = splitConjuncts(conjuncts, joinOutput);
-            List<Expression> joinConditions = split.get(true);
-            List<Expression> nonJoinConditions = split.get(false);
-
-            Optional<Expression> cond;
-            if (joinConditions.isEmpty()) {
-                cond = Optional.empty();
-            } else {
-                cond = Optional.of(ExpressionUtils.and(joinConditions));
-            }
-
-            LogicalJoin join = new LogicalJoin(JoinType.INNER_JOIN, cond, left, right);
-
-            List<Plan> newInputs = new ArrayList<>();
-            newInputs.add(join);
-            newInputs.addAll(candidate.stream().filter(plan -> !right.equals(plan)).collect(Collectors.toList()));
-            return reorderJoinsAccordingToConditions(newInputs, nonJoinConditions);
-        }
-    }
-
-    private Set<Slot> getJoinOutput(Plan left, Plan right) {
-        HashSet<Slot> joinOutput = new HashSet<>();
-        joinOutput.addAll(left.getOutput());
-        joinOutput.addAll(right.getOutput());
-        return joinOutput;
-    }
-
-    private Map<Boolean, List<Expression>> splitConjuncts(List<Expression> conjuncts, Set<Slot> slots) {
-        return conjuncts.stream().collect(Collectors.partitioningBy(
-                // TODO: support non equal to conditions.
-                expr -> expr instanceof EqualTo && slots.containsAll(SlotExtractor.extractSlot(expr))));
-    }
-
-    private class PlanCollector extends PlanVisitor<Void, Void> {
-        public final List<Plan> joinInputs = new ArrayList<>();
-        public final List<Expression> conjuncts = new ArrayList<>();
-
-        @Override
-        public Void visit(Plan plan, Void context) {
-            for (Plan child : plan.children()) {
-                child.accept(this, context);
-            }
-            return null;
-        }
-
-        @Override
-        public Void visitLogicalFilter(LogicalFilter<Plan> filter, Void context) {
-            Plan child = filter.child();
-            if (child instanceof LogicalJoin) {
-                conjuncts.addAll(ExpressionUtils.extractConjunctive(filter.getPredicates()));
-            }
-
-            child.accept(this, context);
-            return null;
-        }
-
-        @Override
-        public Void visitLogicalJoin(LogicalJoin<Plan, Plan> join, Void context) {
-            if (join.getJoinType() != JoinType.CROSS_JOIN && join.getJoinType() != JoinType.INNER_JOIN) {
-                return null;
-            }
-
-            join.left().accept(this, context);
-            join.right().accept(this, context);
-
-            join.getCondition().ifPresent(cond -> conjuncts.addAll(ExpressionUtils.extractConjunctive(cond)));
-            if (!(join.left() instanceof LogicalJoin)) {
-                joinInputs.add(join.left());
-            }
-            if (!(join.right() instanceof LogicalJoin)) {
-                joinInputs.add(join.right());
-            }
-            return null;
-        }
-    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org