You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by hu...@apache.org on 2022/07/27 11:26:08 UTC
[doris] branch master updated: [feature](Nereids): add MultiJoin. (#11254)
This is an automated email from the ASF dual-hosted git repository.
huajianlan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new a2f39278e2 [feature](Nereids): add MultiJoin. (#11254)
a2f39278e2 is described below
commit a2f39278e21cdf25816aa41938d126b62c00664f
Author: jakevin <ja...@gmail.com>
AuthorDate: Wed Jul 27 19:26:02 2022 +0800
[feature](Nereids): add MultiJoin. (#11254)
Add MultiJoin.
In addition, when (joinInputs.size() >= 3 && !conjuncts.isEmpty()), the conjuncts can still contain an onPredicate.
Like:
```
A join B on A.id = B.id where A.sid = B.sid
```
---
.../nereids/rules/rewrite/logical/MultiJoin.java | 196 +++++++++++++++++++++
.../nereids/rules/rewrite/logical/ReorderJoin.java | 168 +-----------------
2 files changed, 199 insertions(+), 165 deletions(-)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java
new file mode 100644
index 0000000000..b85b80cb8a
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java
@@ -0,0 +1,196 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.base.Preconditions;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * A MultiJoin represents a join of N inputs (NAry-Join).
+ * The regular Join represent strictly binary input (Binary-Join).
+ */
+public class MultiJoin extends PlanVisitor<Void, Void> {
+ /*
+ * topJoin
+ * / \ MultiJoin
+ * bottomJoin C --> / | \
+ * / \ A B C
+ * A B
+ */
+ public final List<Plan> joinInputs = new ArrayList<>();
+ public final List<Expression> conjuncts = new ArrayList<>();
+
+ public Plan reorderJoinsAccordingToConditions() {
+ Preconditions.checkArgument(joinInputs.size() >= 2);
+ return reorderJoinsAccordingToConditions(joinInputs, conjuncts);
+ }
+
+ /**
+ * Reorder join orders according to join conditions to eliminate cross join.
+ * <p/>
+ * Let's say we have input join tables: [t1, t2, t3] and
+ * conjunctive predicates: [t1.id=t3.id, t2.id=t3.id]
+ * The input join for t1 and t2 is cross join.
+ * <p/>
+ * The algorithm split join inputs into two groups: `left input` t1 and `candidate right input` [t2, t3].
+ * Try to find an inner join from t1 and candidate right inputs [t2, t3], if any combination
+ * of [Join(t1, t2), Join(t1, t3)] could be optimized to inner join according to the join conditions.
+ * <p/>
+ * As a result, Join(t1, t3) is an inner join.
+ * Then the logic is applied to the rest of [Join(t1, t3), t2] recursively.
+ */
+ private Plan reorderJoinsAccordingToConditions(List<Plan> joinInputs, List<Expression> conjuncts) {
+ if (joinInputs.size() == 2) {
+ Set<Slot> joinOutput = getJoinOutput(joinInputs.get(0), joinInputs.get(1));
+ Map<Boolean, List<Expression>> split = splitConjuncts(conjuncts, joinOutput);
+ List<Expression> joinConditions = split.get(true);
+ List<Expression> nonJoinConditions = split.get(false);
+
+ Optional<Expression> cond;
+ if (joinConditions.isEmpty()) {
+ cond = Optional.empty();
+ } else {
+ cond = Optional.of(ExpressionUtils.and(joinConditions));
+ }
+
+ LogicalJoin join = new LogicalJoin(JoinType.INNER_JOIN, cond, joinInputs.get(0), joinInputs.get(1));
+ if (nonJoinConditions.isEmpty()) {
+ return join;
+ } else {
+ return new LogicalFilter(ExpressionUtils.and(nonJoinConditions), join);
+ }
+ }
+ // input size >= 3;
+ Plan left = joinInputs.get(0);
+ List<Plan> candidate = joinInputs.subList(1, joinInputs.size());
+
+ List<Slot> leftOutput = left.getOutput();
+ Optional<Plan> rightOpt = candidate.stream().filter(right -> {
+ List<Slot> rightOutput = right.getOutput();
+
+ Set<Slot> joinOutput = getJoinOutput(left, right);
+ Optional<Expression> joinCond = conjuncts.stream()
+ .filter(expr -> {
+ Set<Slot> exprInputSlots = SlotExtractor.extractSlot(expr);
+ if (exprInputSlots.isEmpty()) {
+ return false;
+ }
+
+ if (new HashSet<>(leftOutput).containsAll(exprInputSlots)) {
+ return false;
+ }
+
+ if (new HashSet<>(rightOutput).containsAll(exprInputSlots)) {
+ return false;
+ }
+
+ return joinOutput.containsAll(exprInputSlots);
+ }).findFirst();
+ return joinCond.isPresent();
+ }).findFirst();
+
+ Plan right = rightOpt.orElseGet(() -> candidate.get(1));
+ Set<Slot> joinOutput = getJoinOutput(left, right);
+ Map<Boolean, List<Expression>> split = splitConjuncts(conjuncts, joinOutput);
+ List<Expression> joinConditions = split.get(true);
+ List<Expression> nonJoinConditions = split.get(false);
+
+ Optional<Expression> cond;
+ if (joinConditions.isEmpty()) {
+ cond = Optional.empty();
+ } else {
+ cond = Optional.of(ExpressionUtils.and(joinConditions));
+ }
+
+ LogicalJoin join = new LogicalJoin(JoinType.INNER_JOIN, cond, left, right);
+
+ List<Plan> newInputs = new ArrayList<>();
+ newInputs.add(join);
+ newInputs.addAll(candidate.stream().filter(plan -> !right.equals(plan)).collect(Collectors.toList()));
+ return reorderJoinsAccordingToConditions(newInputs, nonJoinConditions);
+ }
+
+ private Map<Boolean, List<Expression>> splitConjuncts(List<Expression> conjuncts, Set<Slot> slots) {
+ return conjuncts.stream().collect(Collectors.partitioningBy(
+ // TODO: support non equal to conditions.
+ expr -> expr instanceof EqualTo && slots.containsAll(SlotExtractor.extractSlot(expr))));
+ }
+
+ private Set<Slot> getJoinOutput(Plan left, Plan right) {
+ HashSet<Slot> joinOutput = new HashSet<>();
+ joinOutput.addAll(left.getOutput());
+ joinOutput.addAll(right.getOutput());
+ return joinOutput;
+ }
+
+ @Override
+ public Void visit(Plan plan, Void context) {
+ for (Plan child : plan.children()) {
+ child.accept(this, context);
+ }
+ return null;
+ }
+
+ @Override
+ public Void visitLogicalFilter(LogicalFilter<Plan> filter, Void context) {
+ Plan child = filter.child();
+ if (child instanceof LogicalJoin) {
+ conjuncts.addAll(ExpressionUtils.extractConjunctive(filter.getPredicates()));
+ }
+
+ child.accept(this, context);
+ return null;
+ }
+
+ @Override
+ public Void visitLogicalJoin(LogicalJoin<Plan, Plan> join, Void context) {
+ if (join.getJoinType() != JoinType.CROSS_JOIN && join.getJoinType() != JoinType.INNER_JOIN) {
+ return null;
+ }
+
+ join.left().accept(this, context);
+ join.right().accept(this, context);
+
+ join.getCondition().ifPresent(cond -> conjuncts.addAll(ExpressionUtils.extractConjunctive(cond)));
+ if (!(join.left() instanceof LogicalJoin)) {
+ joinInputs.add(join.left());
+ }
+ if (!(join.right() instanceof LogicalJoin)) {
+ joinInputs.add(join.right());
+ }
+ return null;
+ }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java
index 7fafe2ec90..c79e16a5de 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java
@@ -20,24 +20,9 @@ package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
-import org.apache.doris.nereids.trees.expressions.EqualTo;
-import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.Slot;
-import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor;
-import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
-import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
-import org.apache.doris.nereids.util.ExpressionUtils;
-
-import java.util.ArrayList;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.Set;
-import java.util.stream.Collectors;
/**
* Try to eliminate cross join via finding join conditions in filters and change the join orders.
@@ -64,157 +49,10 @@ public class ReorderJoin extends OneRewriteRuleFactory {
.isEnableNereidsReorderToEliminateCrossJoin()) {
return filter;
}
- PlanCollector collector = new PlanCollector();
- filter.accept(collector, null);
- List<Plan> joinInputs = collector.joinInputs;
- List<Expression> conjuncts = collector.conjuncts;
+ MultiJoin multiJoin = new MultiJoin();
+ filter.accept(multiJoin, null);
- if (joinInputs.size() >= 3 && !conjuncts.isEmpty()) {
- return reorderJoinsAccordingToConditions(joinInputs, conjuncts);
- } else {
- return filter;
- }
+ return multiJoin.reorderJoinsAccordingToConditions();
}).toRule(RuleType.REORDER_JOIN);
}
-
- /**
- * Reorder join orders according to join conditions to eliminate cross join.
- * <p/>
- * Let's say we have input join tables: [t1, t2, t3] and
- * conjunctive predicates: [t1.id=t3.id, t2.id=t3.id]
- * The input join for t1 and t2 is cross join.
- * <p/>
- * The algorithm split join inputs into two groups: `left input` t1 and `candidate right input` [t2, t3].
- * Try to find an inner join from t1 and candidate right inputs [t2, t3], if any combination
- * of [Join(t1, t2), Join(t1, t3)] could be optimized to inner join according to the join conditions.
- * <p/>
- * As a result, Join(t1, t3) is an inner join.
- * Then the logic is applied to the rest of [Join(t1, t3), t2] recursively.
- */
- private Plan reorderJoinsAccordingToConditions(List<Plan> joinInputs, List<Expression> conjuncts) {
- if (joinInputs.size() == 2) {
- Set<Slot> joinOutput = getJoinOutput(joinInputs.get(0), joinInputs.get(1));
- Map<Boolean, List<Expression>> split = splitConjuncts(conjuncts, joinOutput);
- List<Expression> joinConditions = split.get(true);
- List<Expression> nonJoinConditions = split.get(false);
-
- Optional<Expression> cond;
- if (joinConditions.isEmpty()) {
- cond = Optional.empty();
- } else {
- cond = Optional.of(ExpressionUtils.and(joinConditions));
- }
-
- LogicalJoin join = new LogicalJoin(JoinType.INNER_JOIN, cond, joinInputs.get(0), joinInputs.get(1));
- if (nonJoinConditions.isEmpty()) {
- return join;
- } else {
- return new LogicalFilter(ExpressionUtils.and(nonJoinConditions), join);
- }
- } else {
- Plan left = joinInputs.get(0);
- List<Plan> candidate = joinInputs.subList(1, joinInputs.size());
-
- List<Slot> leftOutput = left.getOutput();
- Optional<Plan> rightOpt = candidate.stream().filter(right -> {
- List<Slot> rightOutput = right.getOutput();
-
- Set<Slot> joinOutput = getJoinOutput(left, right);
- Optional<Expression> joinCond = conjuncts.stream()
- .filter(expr -> {
- Set<Slot> exprInputSlots = SlotExtractor.extractSlot(expr);
- if (exprInputSlots.isEmpty()) {
- return false;
- }
-
- if (new HashSet<>(leftOutput).containsAll(exprInputSlots)) {
- return false;
- }
-
- if (new HashSet<>(rightOutput).containsAll(exprInputSlots)) {
- return false;
- }
-
- return joinOutput.containsAll(exprInputSlots);
- }).findFirst();
- return joinCond.isPresent();
- }).findFirst();
-
- Plan right = rightOpt.orElseGet(() -> candidate.get(1));
- Set<Slot> joinOutput = getJoinOutput(left, right);
- Map<Boolean, List<Expression>> split = splitConjuncts(conjuncts, joinOutput);
- List<Expression> joinConditions = split.get(true);
- List<Expression> nonJoinConditions = split.get(false);
-
- Optional<Expression> cond;
- if (joinConditions.isEmpty()) {
- cond = Optional.empty();
- } else {
- cond = Optional.of(ExpressionUtils.and(joinConditions));
- }
-
- LogicalJoin join = new LogicalJoin(JoinType.INNER_JOIN, cond, left, right);
-
- List<Plan> newInputs = new ArrayList<>();
- newInputs.add(join);
- newInputs.addAll(candidate.stream().filter(plan -> !right.equals(plan)).collect(Collectors.toList()));
- return reorderJoinsAccordingToConditions(newInputs, nonJoinConditions);
- }
- }
-
- private Set<Slot> getJoinOutput(Plan left, Plan right) {
- HashSet<Slot> joinOutput = new HashSet<>();
- joinOutput.addAll(left.getOutput());
- joinOutput.addAll(right.getOutput());
- return joinOutput;
- }
-
- private Map<Boolean, List<Expression>> splitConjuncts(List<Expression> conjuncts, Set<Slot> slots) {
- return conjuncts.stream().collect(Collectors.partitioningBy(
- // TODO: support non equal to conditions.
- expr -> expr instanceof EqualTo && slots.containsAll(SlotExtractor.extractSlot(expr))));
- }
-
- private class PlanCollector extends PlanVisitor<Void, Void> {
- public final List<Plan> joinInputs = new ArrayList<>();
- public final List<Expression> conjuncts = new ArrayList<>();
-
- @Override
- public Void visit(Plan plan, Void context) {
- for (Plan child : plan.children()) {
- child.accept(this, context);
- }
- return null;
- }
-
- @Override
- public Void visitLogicalFilter(LogicalFilter<Plan> filter, Void context) {
- Plan child = filter.child();
- if (child instanceof LogicalJoin) {
- conjuncts.addAll(ExpressionUtils.extractConjunctive(filter.getPredicates()));
- }
-
- child.accept(this, context);
- return null;
- }
-
- @Override
- public Void visitLogicalJoin(LogicalJoin<Plan, Plan> join, Void context) {
- if (join.getJoinType() != JoinType.CROSS_JOIN && join.getJoinType() != JoinType.INNER_JOIN) {
- return null;
- }
-
- join.left().accept(this, context);
- join.right().accept(this, context);
-
- join.getCondition().ifPresent(cond -> conjuncts.addAll(ExpressionUtils.extractConjunctive(cond)));
- if (!(join.left() instanceof LogicalJoin)) {
- joinInputs.add(join.left());
- }
- if (!(join.right() instanceof LogicalJoin)) {
- joinInputs.add(join.right());
- }
- return null;
- }
- }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org