You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by ja...@apache.org on 2023/11/01 11:05:15 UTC
(doris) branch master updated: [enhancement](Nereids): optimize GroupExpressionMatching (#26196)
This is an automated email from the ASF dual-hosted git repository.
jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 6010be88bd4 [enhancement](Nereids): optimize GroupExpressionMatching (#26196)
6010be88bd4 is described below
commit 6010be88bd44139c7d27cede7166f569e5f2f76b
Author: jakevin <ja...@gmail.com>
AuthorDate: Wed Nov 1 19:05:08 2023 +0800
[enhancement](Nereids): optimize GroupExpressionMatching (#26196)
---
.../doris/nereids/jobs/cascades/ApplyRuleJob.java | 2 +-
.../nereids/pattern/GroupExpressionMatching.java | 39 +++++++++++-----------
.../doris/nereids/trees/AbstractTreeNode.java | 13 --------
.../nereids/trees/expressions/Expression.java | 15 +++++++--
.../doris/nereids/trees/plans/AbstractPlan.java | 11 ++++++
.../trees/plans/physical/PhysicalHashJoin.java | 15 +++++----
.../org/apache/doris/nereids/util/JoinUtils.java | 12 ++++---
.../org/apache/doris/nereids/util/PlanUtils.java | 2 +-
.../pattern/GroupExpressionMatchingTest.java | 26 +++++++--------
.../nereids_p0/expression/topn_to_max.groovy | 4 +--
10 files changed, 76 insertions(+), 63 deletions(-)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/ApplyRuleJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/ApplyRuleJob.java
index 3e73850b015..5560c369dd6 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/ApplyRuleJob.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/ApplyRuleJob.java
@@ -61,7 +61,7 @@ public class ApplyRuleJob extends Job {
}
@Override
- public void execute() throws AnalysisException {
+ public final void execute() throws AnalysisException {
if (groupExpression.hasApplied(rule)
|| groupExpression.isUnused()) {
return;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/GroupExpressionMatching.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/GroupExpressionMatching.java
index f73ddcc8868..e281e74a339 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/GroupExpressionMatching.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/GroupExpressionMatching.java
@@ -55,6 +55,7 @@ public class GroupExpressionMatching implements Iterable<Plan> {
public static class GroupExpressionIterator implements Iterator<Plan> {
private final List<Plan> results = Lists.newArrayList();
private int resultIndex = 0;
+ private int resultsSize;
/**
* Constructor.
@@ -103,7 +104,7 @@ public class GroupExpressionMatching implements Iterable<Plan> {
// matching children group, one List<Plan> per child
// first dimension is every child group's plan
// second dimension is all matched plan in one group
- List<List<Plan>> childrenPlans = Lists.newArrayListWithCapacity(childrenGroupArity);
+ List<Plan>[] childrenPlans = new List[childrenGroupArity];
for (int i = 0; i < childrenGroupArity; ++i) {
Group childGroup = groupExpression.child(i);
List<Plan> childrenPlan = matchingChildGroup(pattern, childGroup, i);
@@ -116,7 +117,7 @@ public class GroupExpressionMatching implements Iterable<Plan> {
return;
}
}
- childrenPlans.add(childrenPlan);
+ childrenPlans[i] = childrenPlan;
}
assembleAllCombinationPlanTree(root, pattern, groupExpression, childrenPlans);
} else if (patternArity == 1 && (pattern.hasMultiChild() || pattern.hasMultiGroupChild())) {
@@ -127,6 +128,7 @@ public class GroupExpressionMatching implements Iterable<Plan> {
results.add(root);
}
}
+ this.resultsSize = results.size();
}
private List<Plan> matchingChildGroup(Pattern<? extends Plan> parentPattern,
@@ -154,38 +156,35 @@ public class GroupExpressionMatching implements Iterable<Plan> {
}
private void assembleAllCombinationPlanTree(Plan root, Pattern<Plan> rootPattern,
- GroupExpression groupExpression,
- List<List<Plan>> childrenPlans) {
- int[] childrenPlanIndex = new int[childrenPlans.size()];
+ GroupExpression groupExpression, List<Plan>[] childrenPlans) {
+ int childrenPlansSize = childrenPlans.length;
+ int[] childrenPlanIndex = new int[childrenPlansSize];
int offset = 0;
LogicalProperties logicalProperties = groupExpression.getOwnerGroup().getLogicalProperties();
// assemble all combination of plan tree by current root plan and children plan
- while (offset < childrenPlans.size()) {
- ImmutableList.Builder<Plan> childrenBuilder =
- ImmutableList.builderWithExpectedSize(childrenPlans.size());
- for (int i = 0; i < childrenPlans.size(); i++) {
- childrenBuilder.add(childrenPlans.get(i).get(childrenPlanIndex[i]));
+ Optional<GroupExpression> groupExprOption = Optional.of(groupExpression);
+ Optional<LogicalProperties> logicalPropOption = Optional.of(logicalProperties);
+ while (offset < childrenPlansSize) {
+ ImmutableList.Builder<Plan> childrenBuilder = ImmutableList.builderWithExpectedSize(childrenPlansSize);
+ for (int i = 0; i < childrenPlansSize; i++) {
+ childrenBuilder.add(childrenPlans[i].get(childrenPlanIndex[i]));
}
List<Plan> children = childrenBuilder.build();
// assemble children: replace GroupPlan to real plan,
// withChildren will erase groupExpression, so we must
// withGroupExpression too.
- Plan rootWithChildren = root.withGroupExprLogicalPropChildren(Optional.of(groupExpression),
- Optional.of(logicalProperties), children);
+ Plan rootWithChildren = root.withGroupExprLogicalPropChildren(groupExprOption,
+ logicalPropOption, children);
if (rootPattern.matchPredicates(rootWithChildren)) {
results.add(rootWithChildren);
}
- offset = 0;
- while (true) {
+ for (offset = 0; offset < childrenPlansSize; offset++) {
childrenPlanIndex[offset]++;
- if (childrenPlanIndex[offset] == childrenPlans.get(offset).size()) {
+ if (childrenPlanIndex[offset] == childrenPlans[offset].size()) {
+ // Reset the index when it reaches the size of the current child plan list
childrenPlanIndex[offset] = 0;
- offset++;
- if (offset == childrenPlans.size()) {
- break;
- }
} else {
break;
}
@@ -195,7 +194,7 @@ public class GroupExpressionMatching implements Iterable<Plan> {
@Override
public boolean hasNext() {
- return resultIndex < results.size();
+ return resultIndex < resultsSize;
}
@Override
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/AbstractTreeNode.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/AbstractTreeNode.java
index 0305ae2afad..7a545ec17be 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/AbstractTreeNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/AbstractTreeNode.java
@@ -17,10 +17,6 @@
package org.apache.doris.nereids.trees;
-import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
-import org.apache.doris.nereids.trees.plans.ObjectId;
-import org.apache.doris.planner.PlanNodeId;
-
import com.google.common.collect.ImmutableList;
import java.util.List;
@@ -33,7 +29,6 @@ import java.util.List;
*/
public abstract class AbstractTreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>>
implements TreeNode<NODE_TYPE> {
- protected final ObjectId id = StatementScopeIdGenerator.newObjectId();
protected final List<NODE_TYPE> children;
// TODO: Maybe we should use a GroupPlan to avoid TreeNode hold the GroupExpression.
// https://github.com/apache/doris/pull/9807#discussion_r884829067
@@ -59,12 +54,4 @@ public abstract class AbstractTreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>>
public int arity() {
return children.size();
}
-
- /**
- * used for PhysicalPlanTranslator only
- * @return PlanNodeId
- */
- public PlanNodeId translatePlanNodeId() {
- return id.toPlanNodeId();
- }
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java
index 3f0370d7c3b..12a3a9768ca 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java
@@ -37,6 +37,7 @@ import org.apache.doris.nereids.types.coercion.AnyDataType;
import org.apache.doris.nereids.util.Utils;
import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import org.apache.commons.lang3.StringUtils;
@@ -44,7 +45,6 @@ import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Set;
-import java.util.stream.Collectors;
/**
* Abstract class for all Expression in Nereids.
@@ -247,8 +247,19 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
return collect(Slot.class::isInstance);
}
+ /**
+ * Get all the input slot ids of the expression.
+ * <p>
+ * Note that the input slots of a subquery's inner plan are not included.
+ */
public final Set<ExprId> getInputSlotExprIds() {
- return getInputSlots().stream().map(NamedExpression::getExprId).collect(Collectors.toSet());
+ ImmutableSet.Builder<ExprId> result = ImmutableSet.builder();
+ foreach(node -> {
+ if (node instanceof Slot) {
+ result.add(((Slot) node).getExprId());
+ }
+ });
+ return result.build();
}
public boolean isLiteral() {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AbstractPlan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AbstractPlan.java
index 38a209ff55f..c223dd43b6e 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AbstractPlan.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AbstractPlan.java
@@ -24,9 +24,11 @@ import org.apache.doris.nereids.properties.UnboundLogicalProperties;
import org.apache.doris.nereids.trees.AbstractTreeNode;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.util.MutableState;
import org.apache.doris.nereids.util.MutableState.EmptyMutableState;
import org.apache.doris.nereids.util.TreeStringUtils;
+import org.apache.doris.planner.PlanNodeId;
import org.apache.doris.statistics.Statistics;
import com.google.common.base.Supplier;
@@ -45,6 +47,7 @@ import javax.annotation.Nullable;
*/
public abstract class AbstractPlan extends AbstractTreeNode<Plan> implements Plan {
public static final String FRAGMENT_ID = "fragment";
+ protected final ObjectId id = StatementScopeIdGenerator.newObjectId();
protected final Statistics statistics;
protected final PlanType type;
@@ -168,4 +171,12 @@ public abstract class AbstractPlan extends AbstractTreeNode<Plan> implements Pla
public void setMutableState(String key, Object state) {
this.mutableState = this.mutableState.set(key, state);
}
+
+ /**
+ * used for PhysicalPlanTranslator only
+ * @return PlanNodeId
+ */
+ public PlanNodeId translatePlanNodeId() {
+ return id.toPlanNodeId();
+ }
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashJoin.java
index b60afd67308..994b4d4f971 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashJoin.java
@@ -43,9 +43,9 @@ import org.apache.doris.statistics.Statistics;
import org.apache.doris.thrift.TRuntimeFilterType;
import com.google.common.base.Preconditions;
-import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
+import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -113,22 +113,25 @@ public class PhysicalHashJoin<
* Return pair of left used slots and right used slots.
*/
public Pair<List<ExprId>, List<ExprId>> getHashConjunctsExprIds() {
- List<ExprId> exprIds1 = Lists.newArrayListWithCapacity(hashJoinConjuncts.size());
- List<ExprId> exprIds2 = Lists.newArrayListWithCapacity(hashJoinConjuncts.size());
+ int size = hashJoinConjuncts.size();
+
+ List<ExprId> exprIds1 = new ArrayList<>(size);
+ List<ExprId> exprIds2 = new ArrayList<>(size);
Set<ExprId> leftExprIds = left().getOutputExprIdSet();
Set<ExprId> rightExprIds = right().getOutputExprIdSet();
for (Expression expr : hashJoinConjuncts) {
- expr.getInputSlotExprIds().forEach(exprId -> {
+ for (ExprId exprId : expr.getInputSlotExprIds()) {
if (leftExprIds.contains(exprId)) {
exprIds1.add(exprId);
} else if (rightExprIds.contains(exprId)) {
exprIds2.add(exprId);
} else {
- throw new RuntimeException("Could not generate valid equal on clause slot pairs for join");
+ throw new RuntimeException("Invalid ExprId found: " + exprId
+ + ". Cannot generate valid equal on clause slot pairs for join.");
}
- });
+ }
}
return Pair.of(exprIds1, exprIds2);
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java
index d1fb973dd61..bcf53ce29f8 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java
@@ -38,6 +38,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.qe.ConnectContext;
+import org.apache.doris.qe.SessionVariable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
@@ -66,9 +67,10 @@ public class JoinUtils {
* check if the row count of the left child in the broadcast join is less than a threshold value.
*/
public static boolean checkBroadcastJoinStats(PhysicalHashJoin<? extends Plan, ? extends Plan> join) {
- double memLimit = ConnectContext.get().getSessionVariable().getMaxExecMemByte();
- double rowsLimit = ConnectContext.get().getSessionVariable().getBroadcastRowCountLimit();
- double brMemlimit = ConnectContext.get().getSessionVariable().getBroadcastHashtableMemLimitPercentage();
+ SessionVariable sessionVariable = ConnectContext.get().getSessionVariable();
+ double memLimit = sessionVariable.getMaxExecMemByte();
+ double rowsLimit = sessionVariable.getBroadcastRowCountLimit();
+ double brMemlimit = sessionVariable.getBroadcastHashtableMemLimitPercentage();
double datasize = join.getGroupExpression().get().child(1).getStatistics().computeSize();
double rowCount = join.getGroupExpression().get().child(1).getStatistics().getRowCount();
return rowCount <= rowsLimit && datasize <= memLimit * brMemlimit;
@@ -114,12 +116,12 @@ public class JoinUtils {
* @return true if the equal can be used as hash join condition
*/
public boolean isHashJoinCondition(EqualTo equalTo) {
- Set<Slot> equalLeft = equalTo.left().collect(Slot.class::isInstance);
+ Set<Slot> equalLeft = equalTo.left().getInputSlots();
if (equalLeft.isEmpty()) {
return false;
}
- Set<Slot> equalRight = equalTo.right().collect(Slot.class::isInstance);
+ Set<Slot> equalRight = equalTo.right().getInputSlots();
if (equalRight.isEmpty()) {
return false;
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java
index 17034c15e6e..48eb452a74c 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java
@@ -55,7 +55,7 @@ public class PlanUtils {
* normalize comparison predicate on a binary plan to its two sides are corresponding to the child's output.
*/
public static ComparisonPredicate maybeCommuteComparisonPredicate(ComparisonPredicate expression, Plan left) {
- Set<Slot> slots = expression.left().collect(Slot.class::isInstance);
+ Set<Slot> slots = expression.left().getInputSlots();
Set<Slot> leftSlots = left.getOutputSet();
Set<Slot> buffer = Sets.newHashSet(slots);
buffer.removeAll(leftSlots);
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/pattern/GroupExpressionMatchingTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/pattern/GroupExpressionMatchingTest.java
index 53a459859b2..6a4d38b5adb 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/pattern/GroupExpressionMatchingTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/pattern/GroupExpressionMatchingTest.java
@@ -42,10 +42,10 @@ import org.junit.jupiter.api.Test;
import java.util.Iterator;
-public class GroupExpressionMatchingTest {
+class GroupExpressionMatchingTest {
@Test
- public void testLeafNode() {
+ void testLeafNode() {
Pattern pattern = new Pattern<>(PlanType.LOGICAL_UNBOUND_RELATION);
Memo memo = new Memo(null, new UnboundRelation(StatementScopeIdGenerator.newRelationId(), Lists.newArrayList("test")));
@@ -61,7 +61,7 @@ public class GroupExpressionMatchingTest {
}
@Test
- public void testDepth2() {
+ void testDepth2() {
Pattern pattern = new Pattern<>(PlanType.LOGICAL_PROJECT,
new Pattern<>(PlanType.LOGICAL_UNBOUND_RELATION));
@@ -93,7 +93,7 @@ public class GroupExpressionMatchingTest {
}
@Test
- public void testDepth2WithGroup() {
+ void testDepth2WithGroup() {
Pattern pattern = new Pattern<>(PlanType.LOGICAL_PROJECT, Pattern.GROUP);
Plan leaf = new UnboundRelation(StatementScopeIdGenerator.newRelationId(), Lists.newArrayList("test"));
@@ -119,7 +119,7 @@ public class GroupExpressionMatchingTest {
}
@Test
- public void testLeafAny() {
+ void testLeafAny() {
Pattern pattern = Pattern.ANY;
Memo memo = new Memo(null, new UnboundRelation(StatementScopeIdGenerator.newRelationId(), Lists.newArrayList("test")));
@@ -135,7 +135,7 @@ public class GroupExpressionMatchingTest {
}
@Test
- public void testAnyWithChild() {
+ void testAnyWithChild() {
Plan root = new LogicalProject(
ImmutableList.of(new SlotReference("name", StringType.INSTANCE, true,
ImmutableList.of("test"))),
@@ -159,7 +159,7 @@ public class GroupExpressionMatchingTest {
}
@Test
- public void testInnerLogicalJoinMatch() {
+ void testInnerLogicalJoinMatch() {
Plan root = new LogicalJoin(JoinType.INNER_JOIN,
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("a")),
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("b"))
@@ -181,7 +181,7 @@ public class GroupExpressionMatchingTest {
}
@Test
- public void testInnerLogicalJoinMismatch() {
+ void testInnerLogicalJoinMismatch() {
Plan root = new LogicalJoin(JoinType.LEFT_OUTER_JOIN,
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("a")),
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("b"))
@@ -198,7 +198,7 @@ public class GroupExpressionMatchingTest {
}
@Test
- public void testTopMatchButChildrenNotMatch() {
+ void testTopMatchButChildrenNotMatch() {
Plan root = new LogicalJoin(JoinType.LEFT_OUTER_JOIN,
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("a")),
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("b"))
@@ -216,12 +216,12 @@ public class GroupExpressionMatchingTest {
}
@Test
- public void testSubTreeMatch() {
+ void testSubTreeMatch() {
Plan root =
- new LogicalFilter(ImmutableSet.of(new EqualTo(new UnboundSlot(Lists.newArrayList("a", "id")),
+ new LogicalFilter<>(ImmutableSet.of(new EqualTo(new UnboundSlot(Lists.newArrayList("a", "id")),
new UnboundSlot(Lists.newArrayList("b", "id")))),
- new LogicalJoin(JoinType.INNER_JOIN,
- new LogicalJoin(JoinType.LEFT_OUTER_JOIN,
+ new LogicalJoin<>(JoinType.INNER_JOIN,
+ new LogicalJoin<>(JoinType.LEFT_OUTER_JOIN,
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("a")),
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("b"))),
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("c")))
diff --git a/regression-test/suites/nereids_p0/expression/topn_to_max.groovy b/regression-test/suites/nereids_p0/expression/topn_to_max.groovy
index ae848b5a244..4c05b42ccc3 100644
--- a/regression-test/suites/nereids_p0/expression/topn_to_max.groovy
+++ b/regression-test/suites/nereids_p0/expression/topn_to_max.groovy
@@ -31,7 +31,7 @@ suite("test_topn_to_max") {
group by k1;
'''
res = sql '''
- explain rewritten plan select k1, max(k2)
+ explain rewritten plan select k1, topn(k2, 1)
from test_topn_to_max
group by k1;
'''
@@ -42,7 +42,7 @@ suite("test_topn_to_max") {
from test_topn_to_max;
'''
res = sql '''
- explain rewritten plan select max(k2)
+ explain rewritten plan select topn(k2, 1)
from test_topn_to_max;
'''
assertTrue(res.toString().contains("max"))
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org