You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by mo...@apache.org on 2022/10/20 14:23:42 UTC
[doris] branch master updated: [fix](Nereids) NPE caused by GroupExpression has null owner group when choosing best plan (#13252)
This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 4ae777bfc5 [fix](Nereids) NPE caused by GroupExpression has null owner group when choosing best plan (#13252)
4ae777bfc5 is described below
commit 4ae777bfc55a8b0fa1721011a088f8d9392d9327
Author: Kikyou1997 <33...@users.noreply.github.com>
AuthorDate: Thu Oct 20 22:23:36 2022 +0800
[fix](Nereids) NPE caused by GroupExpression has null owner group when choosing best plan (#13252)
---
.../org/apache/doris/nereids/NereidsPlanner.java | 42 +--
.../nereids/jobs/cascades/DeriveStatsJob.java | 5 +-
.../java/org/apache/doris/nereids/memo/Group.java | 3 +-
.../apache/doris/nereids/memo/GroupExpression.java | 11 +-
.../java/org/apache/doris/nereids/memo/Memo.java | 13 +-
.../apache/doris/nereids/trees/plans/FakePlan.java | 104 +++++++
.../apache/doris/nereids/memo/MemoCopyInTest.java | 84 ------
.../apache/doris/nereids/memo/MemoInitTest.java | 182 ------------
.../memo/{MemoRewriteTest.java => MemoTest.java} | 307 ++++++++++++++++++---
.../suites/tpch_sf1_p1/tpch_sf1/nereids/q21.groovy | 107 +++++++
10 files changed, 520 insertions(+), 338 deletions(-)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java
index ff1ec80890..59cb499900 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java
@@ -190,25 +190,31 @@ public class NereidsPlanner extends Planner {
private PhysicalPlan chooseBestPlan(Group rootGroup, PhysicalProperties physicalProperties)
throws AnalysisException {
- GroupExpression groupExpression = rootGroup.getLowestCostPlan(physicalProperties).orElseThrow(
- () -> new AnalysisException("lowestCostPlans with physicalProperties doesn't exist")).second;
- List<PhysicalProperties> inputPropertiesList = groupExpression.getInputPropertiesList(physicalProperties);
-
- List<Plan> planChildren = Lists.newArrayList();
- for (int i = 0; i < groupExpression.arity(); i++) {
- planChildren.add(chooseBestPlan(groupExpression.child(i), inputPropertiesList.get(i)));
- }
-
- Plan plan = groupExpression.getPlan().withChildren(planChildren);
- if (!(plan instanceof PhysicalPlan)) {
- throw new AnalysisException("Result plan must be PhysicalPlan");
+ try {
+ GroupExpression groupExpression = rootGroup.getLowestCostPlan(physicalProperties).orElseThrow(
+ () -> new AnalysisException("lowestCostPlans with physicalProperties doesn't exist")).second;
+ List<PhysicalProperties> inputPropertiesList = groupExpression.getInputPropertiesList(physicalProperties);
+
+ List<Plan> planChildren = Lists.newArrayList();
+ for (int i = 0; i < groupExpression.arity(); i++) {
+ planChildren.add(chooseBestPlan(groupExpression.child(i), inputPropertiesList.get(i)));
+ }
+
+ Plan plan = groupExpression.getPlan().withChildren(planChildren);
+ if (!(plan instanceof PhysicalPlan)) {
+ throw new AnalysisException("Result plan must be PhysicalPlan");
+ }
+
+ // TODO: set (logical and physical)properties/statistics/... for physicalPlan.
+ PhysicalPlan physicalPlan = ((PhysicalPlan) plan).withPhysicalPropertiesAndStats(
+ groupExpression.getOutputProperties(physicalProperties),
+ groupExpression.getOwnerGroup().getStatistics());
+ return physicalPlan;
+ } catch (Exception e) {
+ String memo = cascadesContext.getMemo().toString();
+ LOG.warn("Failed to choose best plan, memo structure:{}", memo, e);
+ throw new AnalysisException("Failed to choose best plan", e);
}
-
- // TODO: set (logical and physical)properties/statistics/... for physicalPlan.
- PhysicalPlan physicalPlan = ((PhysicalPlan) plan).withPhysicalPropertiesAndStats(
- groupExpression.getOutputProperties(physicalProperties),
- groupExpression.getOwnerGroup().getStatistics());
- return physicalPlan;
}
@Override
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJob.java
index 618382438a..8797327311 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJob.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJob.java
@@ -60,9 +60,8 @@ public class DeriveStatsJob extends Job {
deriveChildren = true;
pushJob(new DeriveStatsJob(this));
for (Group child : groupExpression.children()) {
- GroupExpression childGroupExpr = child.getLogicalExpressions().get(0);
- if (!child.getLogicalExpressions().isEmpty() && !childGroupExpr.isStatDerived()) {
- pushJob(new DeriveStatsJob(childGroupExpr, context));
+ if (!child.getLogicalExpressions().isEmpty()) {
+ pushJob(new DeriveStatsJob(child.getLogicalExpressions().get(0), context));
}
}
} else {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java
index 6a65f4a373..fbb0a0d32b 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java
@@ -429,9 +429,8 @@ public class Group {
lowestCostPlans.forEach((physicalProperties, costAndGroupExpr) -> {
GroupExpression bestGroupExpression = costAndGroupExpr.second;
// change into target group.
- if (bestGroupExpression.getOwnerGroup() == this) {
+ if (bestGroupExpression.getOwnerGroup() == this || bestGroupExpression.getOwnerGroup() == null) {
bestGroupExpression.setOwnerGroup(target);
- bestGroupExpression.children().set(0, target);
}
if (!target.lowestCostPlans.containsKey(physicalProperties)) {
target.lowestCostPlans.put(physicalProperties, costAndGroupExpr);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java
index a6f555b649..e8de9a6e54 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java
@@ -229,11 +229,20 @@ public class GroupExpression {
public String toString() {
StringBuilder builder = new StringBuilder();
builder.append(ownerGroup.getGroupId()).append("(plan=").append(plan).append(") children=[");
+ if (ownerGroup == null) {
+ builder.append("OWNER GROUP IS NULL[]");
+ } else {
+ builder.append(ownerGroup.getGroupId()).append("(plan=").append(plan.toString()).append(") children=[");
+ }
for (Group group : children) {
builder.append(group.getGroupId()).append(" ");
}
builder.append("] stats=");
- builder.append(ownerGroup.getStatistics());
+ if (ownerGroup != null) {
+ builder.append(ownerGroup.getStatistics());
+ } else {
+ builder.append("NULL");
+ }
return builder.toString();
}
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
index a484584a49..abafce3880 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
@@ -47,6 +47,11 @@ public class Memo {
private final Map<GroupExpression, GroupExpression> groupExpressions = Maps.newHashMap();
private final Group root;
+ // FOR TEST ONLY
+ public Memo() {
+ root = null;
+ }
+
public Memo(Plan plan) {
root = init(plan);
}
@@ -325,20 +330,20 @@ public class Memo {
* @param destination destination group
* @return merged group
*/
- private Group mergeGroup(Group source, Group destination) {
+ public Group mergeGroup(Group source, Group destination) {
if (source.equals(destination)) {
return source;
}
List<GroupExpression> needReplaceChild = Lists.newArrayList();
- groupExpressions.values().forEach(groupExpression -> {
+ for (GroupExpression groupExpression : groupExpressions.values()) {
if (groupExpression.children().contains(source)) {
if (groupExpression.getOwnerGroup().equals(destination)) {
// cycle, we should not merge
- return;
+ return null;
}
needReplaceChild.add(groupExpression);
}
- });
+ }
for (GroupExpression groupExpression : needReplaceChild) {
groupExpressions.remove(groupExpression);
List<Group> children = groupExpression.children();
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/FakePlan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/FakePlan.java
new file mode 100644
index 0000000000..2058ad37d2
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/FakePlan.java
@@ -0,0 +1,104 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.plans;
+
+import org.apache.doris.nereids.memo.GroupExpression;
+import org.apache.doris.nereids.properties.LogicalProperties;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * Used for unit test only.
+ */
+public class FakePlan implements Plan {
+
+ @Override
+ public List<Plan> children() {
+ return null;
+ }
+
+ @Override
+ public Plan child(int index) {
+ return null;
+ }
+
+ @Override
+ public int arity() {
+ return 0;
+ }
+
+ @Override
+ public Plan withChildren(List<Plan> children) {
+ return null;
+ }
+
+ @Override
+ public PlanType getType() {
+ return null;
+ }
+
+ @Override
+ public Optional<GroupExpression> getGroupExpression() {
+ return Optional.empty();
+ }
+
+ @Override
+ public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
+ return null;
+ }
+
+ @Override
+ public List<? extends Expression> getExpressions() {
+ return new ArrayList<>();
+ }
+
+ @Override
+ public LogicalProperties getLogicalProperties() {
+ return new LogicalProperties(ArrayList::new);
+ }
+
+ @Override
+ public boolean canBind() {
+ return false;
+ }
+
+ @Override
+ public List<Slot> getOutput() {
+ return new ArrayList<>();
+ }
+
+ @Override
+ public String treeString() {
+ return "DUMMY";
+ }
+
+ @Override
+ public Plan withGroupExpression(Optional<GroupExpression> groupExpression) {
+ return this;
+ }
+
+ @Override
+ public Plan withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
+ return this;
+ }
+}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoCopyInTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoCopyInTest.java
deleted file mode 100644
index 5cf28d4528..0000000000
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoCopyInTest.java
+++ /dev/null
@@ -1,84 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.nereids.memo;
-
-import org.apache.doris.nereids.trees.plans.JoinType;
-import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
-import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
-import org.apache.doris.nereids.util.MemoTestUtils;
-import org.apache.doris.nereids.util.PatternMatchSupported;
-import org.apache.doris.nereids.util.PlanChecker;
-import org.apache.doris.nereids.util.PlanConstructor;
-
-import com.google.common.collect.Lists;
-import org.junit.jupiter.api.Assertions;
-import org.junit.jupiter.api.Test;
-
-public class MemoCopyInTest implements PatternMatchSupported {
- LogicalJoin<LogicalOlapScan, LogicalOlapScan> logicalJoinAB = new LogicalJoin<>(JoinType.INNER_JOIN,
- PlanConstructor.newLogicalOlapScan(0, "A", 0), PlanConstructor.newLogicalOlapScan(1, "B", 0));
- LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, LogicalOlapScan> logicalJoinABC = new LogicalJoin<>(
- JoinType.INNER_JOIN, logicalJoinAB, PlanConstructor.newLogicalOlapScan(2, "C", 0));
-
- /**
- * Original:
- * Group 0: LogicalOlapScan C
- * Group 1: LogicalOlapScan B
- * Group 2: LogicalOlapScan A
- * Group 3: Join(Group 1, Group 2)
- * Group 4: Join(Group 0, Group 3)
- * <p>
- * Then:
- * Copy In Join(Group 2, Group 1) into Group 3
- * <p>
- * Expected:
- * Group 0: LogicalOlapScan C
- * Group 1: LogicalOlapScan B
- * Group 2: LogicalOlapScan A
- * Group 3: Join(Group 1, Group 2), Join(Group 2, Group 1)
- * Group 4: Join(Group 0, Group 3)
- */
- @Test
- public void testInsertSameGroup() {
- PlanChecker.from(MemoTestUtils.createConnectContext(), logicalJoinABC)
- .transform(
- // swap join's children
- logicalJoin(logicalOlapScan(), logicalOlapScan()).then(joinBA ->
- new LogicalProject<>(Lists.newArrayList(joinBA.getOutput()),
- new LogicalJoin<>(JoinType.INNER_JOIN, joinBA.right(), joinBA.left()))
- ))
- .checkGroupNum(6)
- .checkGroupExpressionNum(7)
- .checkMemo(memo -> {
- Group root = memo.getRoot();
- Assertions.assertEquals(1, root.getLogicalExpressions().size());
- GroupExpression joinABC = root.getLogicalExpression();
- Assertions.assertEquals(2, joinABC.child(0).getLogicalExpressions().size());
- Assertions.assertEquals(1, joinABC.child(1).getLogicalExpressions().size());
- GroupExpression joinAB = joinABC.child(0).getLogicalExpressions().get(0);
- GroupExpression project = joinABC.child(0).getLogicalExpressions().get(1);
- GroupExpression joinBA = project.child(0).getLogicalExpression();
- Assertions.assertTrue(joinAB.getPlan() instanceof LogicalJoin);
- Assertions.assertTrue(joinBA.getPlan() instanceof LogicalJoin);
- });
-
- }
-
- // TODO: test mergeGroup().
-}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoInitTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoInitTest.java
deleted file mode 100644
index 81f13afc80..0000000000
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoInitTest.java
+++ /dev/null
@@ -1,182 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.nereids.memo;
-
-import org.apache.doris.catalog.OlapTable;
-import org.apache.doris.common.IdGenerator;
-import org.apache.doris.nereids.analyzer.UnboundRelation;
-import org.apache.doris.nereids.trees.expressions.EqualTo;
-import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
-import org.apache.doris.nereids.trees.plans.JoinType;
-import org.apache.doris.nereids.trees.plans.RelationId;
-import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
-import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
-import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
-import org.apache.doris.nereids.util.MemoTestUtils;
-import org.apache.doris.nereids.util.PatternMatchSupported;
-import org.apache.doris.nereids.util.PlanChecker;
-import org.apache.doris.nereids.util.PlanConstructor;
-import org.apache.doris.qe.ConnectContext;
-
-import com.google.common.collect.ImmutableList;
-import org.junit.jupiter.api.Test;
-
-import java.util.Objects;
-
-public class MemoInitTest implements PatternMatchSupported {
- private ConnectContext connectContext = MemoTestUtils.createConnectContext();
-
- @Test
- public void initByOneLevelPlan() {
- OlapTable table = PlanConstructor.newOlapTable(0, "a", 1);
- LogicalOlapScan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), table);
-
- PlanChecker.from(connectContext, scan)
- .checkGroupNum(1)
- .matches(
- logicalOlapScan().when(scan::equals)
- );
- }
-
- @Test
- public void initByTwoLevelChainPlan() {
- OlapTable table = PlanConstructor.newOlapTable(0, "a", 1);
- LogicalOlapScan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), table);
-
- LogicalProject<LogicalOlapScan> topProject = new LogicalProject<>(
- ImmutableList.of(scan.computeOutput().get(0)), scan);
-
- PlanChecker.from(connectContext, topProject)
- .checkGroupNum(2)
- .matches(
- logicalProject(
- any().when(child -> Objects.equals(child, scan))
- ).when(root -> Objects.equals(root, topProject))
- );
- }
-
- @Test
- public void initByJoinSameUnboundTable() {
- UnboundRelation scanA = new UnboundRelation(ImmutableList.of("a"));
-
- LogicalJoin<UnboundRelation, UnboundRelation> topJoin = new LogicalJoin<>(JoinType.INNER_JOIN, scanA, scanA);
-
- PlanChecker.from(connectContext, topJoin)
- .checkGroupNum(3)
- .matches(
- logicalJoin(
- any().when(left -> Objects.equals(left, scanA)),
- any().when(right -> Objects.equals(right, scanA))
- ).when(root -> Objects.equals(root, topJoin))
- );
- }
-
- @Test
- public void initByJoinSameLogicalTable() {
- IdGenerator<RelationId> generator = RelationId.createGenerator();
- OlapTable tableA = PlanConstructor.newOlapTable(0, "a", 1);
- LogicalOlapScan scanA = new LogicalOlapScan(generator.getNextId(), tableA);
- LogicalOlapScan scanA1 = new LogicalOlapScan(generator.getNextId(), tableA);
-
- LogicalJoin<LogicalOlapScan, LogicalOlapScan> topJoin = new LogicalJoin<>(JoinType.INNER_JOIN, scanA, scanA1);
-
- PlanChecker.from(connectContext, topJoin)
- .checkGroupNum(3)
- .matches(
- logicalJoin(
- any().when(left -> Objects.equals(left, scanA)),
- any().when(right -> Objects.equals(right, scanA1))
- ).when(root -> Objects.equals(root, topJoin))
- );
- }
-
- @Test
- public void initByTwoLevelJoinPlan() {
- IdGenerator<RelationId> generator = RelationId.createGenerator();
- OlapTable tableA = PlanConstructor.newOlapTable(0, "a", 1);
- OlapTable tableB = PlanConstructor.newOlapTable(0, "b", 1);
- LogicalOlapScan scanA = new LogicalOlapScan(generator.getNextId(), tableA);
- LogicalOlapScan scanB = new LogicalOlapScan(generator.getNextId(), tableB);
-
- LogicalJoin<LogicalOlapScan, LogicalOlapScan> topJoin = new LogicalJoin<>(JoinType.INNER_JOIN, scanA, scanB);
-
- PlanChecker.from(connectContext, topJoin)
- .checkGroupNum(3)
- .matches(
- logicalJoin(
- any().when(left -> Objects.equals(left, scanA)),
- any().when(right -> Objects.equals(right, scanB))
- ).when(root -> Objects.equals(root, topJoin))
- );
- }
-
- @Test
- public void initByThreeLevelChainPlan() {
- OlapTable table = PlanConstructor.newOlapTable(0, "a", 1);
- LogicalOlapScan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), table);
-
- LogicalProject<LogicalOlapScan> project = new LogicalProject<>(
- ImmutableList.of(scan.computeOutput().get(0)), scan);
- LogicalFilter<LogicalProject<LogicalOlapScan>> filter = new LogicalFilter<>(
- new EqualTo(scan.computeOutput().get(0), new IntegerLiteral(1)), project);
-
- PlanChecker.from(connectContext, filter)
- .checkGroupNum(3)
- .matches(
- logicalFilter(
- logicalProject(
- any().when(child -> Objects.equals(child, scan))
- ).when(root -> Objects.equals(root, project))
- ).when(root -> Objects.equals(root, filter))
- );
- }
-
- @Test
- public void initByThreeLevelBushyPlan() {
- IdGenerator<RelationId> generator = RelationId.createGenerator();
- OlapTable tableA = PlanConstructor.newOlapTable(0, "a", 1);
- OlapTable tableB = PlanConstructor.newOlapTable(0, "b", 1);
- OlapTable tableC = PlanConstructor.newOlapTable(0, "c", 1);
- OlapTable tableD = PlanConstructor.newOlapTable(0, "d", 1);
- LogicalOlapScan scanA = new LogicalOlapScan(generator.getNextId(), tableA);
- LogicalOlapScan scanB = new LogicalOlapScan(generator.getNextId(), tableB);
- LogicalOlapScan scanC = new LogicalOlapScan(generator.getNextId(), tableC);
- LogicalOlapScan scanD = new LogicalOlapScan(generator.getNextId(), tableD);
-
- LogicalJoin<LogicalOlapScan, LogicalOlapScan> leftJoin = new LogicalJoin<>(JoinType.CROSS_JOIN, scanA, scanB);
- LogicalJoin<LogicalOlapScan, LogicalOlapScan> rightJoin = new LogicalJoin<>(JoinType.CROSS_JOIN, scanC, scanD);
- LogicalJoin topJoin = new LogicalJoin<>(JoinType.CROSS_JOIN, leftJoin, rightJoin);
-
- PlanChecker.from(connectContext, topJoin)
- .checkGroupNum(7)
- .matches(
- logicalJoin(
- logicalJoin(
- any().when(child -> Objects.equals(child, scanA)),
- any().when(child -> Objects.equals(child, scanB))
- ).when(left -> Objects.equals(left, leftJoin)),
-
- logicalJoin(
- any().when(child -> Objects.equals(child, scanC)),
- any().when(child -> Objects.equals(child, scanD))
- ).when(right -> Objects.equals(right, rightJoin))
- ).when(root -> Objects.equals(root, topJoin))
- );
- }
-}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoRewriteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java
similarity index 74%
rename from fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoRewriteTest.java
rename to fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java
index 7a82c5dcd2..9ad9bf8d16 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoRewriteTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java
@@ -17,17 +17,24 @@
package org.apache.doris.nereids.memo;
+import org.apache.doris.catalog.OlapTable;
+import org.apache.doris.common.IdGenerator;
+import org.apache.doris.common.jmockit.Deencapsulation;
import org.apache.doris.nereids.analyzer.UnboundRelation;
import org.apache.doris.nereids.analyzer.UnboundSlot;
import org.apache.doris.nereids.properties.LogicalProperties;
+import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.properties.UnboundLogicalProperties;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
+import org.apache.doris.nereids.trees.plans.FakePlan;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.LeafPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.RelationId;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
@@ -45,13 +52,235 @@ import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
+import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
+import java.util.Map;
import java.util.Objects;
import java.util.Optional;
-public class MemoRewriteTest implements PatternMatchSupported {
+class MemoTest implements PatternMatchSupported {
+
private ConnectContext connectContext = MemoTestUtils.createConnectContext();
+ private LogicalJoin<LogicalOlapScan, LogicalOlapScan> logicalJoinAB = new LogicalJoin<>(JoinType.INNER_JOIN,
+ PlanConstructor.newLogicalOlapScan(0, "A", 0),
+ PlanConstructor.newLogicalOlapScan(1, "B", 0));
+
+ private LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, LogicalOlapScan> logicalJoinABC = new LogicalJoin<>(
+ JoinType.INNER_JOIN, logicalJoinAB, PlanConstructor.newLogicalOlapScan(2, "C", 0));
+
+ @Test
+ void mergeGroup() throws Exception {
+ Memo memo = new Memo();
+ GroupId gid2 = new GroupId(2);
+ Group srcGroup = new Group(gid2, new GroupExpression(new FakePlan()), new LogicalProperties(ArrayList::new));
+ GroupId gid3 = new GroupId(3);
+ Group dstGroup = new Group(gid3, new GroupExpression(new FakePlan()), new LogicalProperties(ArrayList::new));
+ FakePlan d = new FakePlan();
+ GroupExpression ge1 = new GroupExpression(d, Arrays.asList(srcGroup));
+ GroupId gid0 = new GroupId(0);
+ Group g1 = new Group(gid0, ge1, new LogicalProperties(ArrayList::new));
+ g1.setBestPlan(ge1, Double.MIN_VALUE, PhysicalProperties.ANY);
+ GroupExpression ge2 = new GroupExpression(d, Arrays.asList(dstGroup));
+ GroupId gid1 = new GroupId(1);
+ Group g2 = new Group(gid1, ge2, new LogicalProperties(ArrayList::new));
+ Map<GroupId, Group> groups = (Map<GroupId, Group>) Deencapsulation.getField(memo, "groups");
+ groups.put(gid2, srcGroup);
+ groups.put(gid3, dstGroup);
+ groups.put(gid0, g1);
+ groups.put(gid1, g2);
+ Map<GroupExpression, GroupExpression> groupExpressions =
+ (Map<GroupExpression, GroupExpression>) Deencapsulation.getField(memo, "groupExpressions");
+ groupExpressions.put(ge1, ge1);
+ groupExpressions.put(ge2, ge2);
+ memo.mergeGroup(srcGroup, dstGroup);
+ Assertions.assertNull(g1.getBestPlan(PhysicalProperties.ANY));
+ Assertions.assertEquals(ge1.getOwnerGroup(), g2);
+ }
+
+ /**
+ * Original:
+ * Group 0: LogicalOlapScan C
+ * Group 1: LogicalOlapScan B
+ * Group 2: LogicalOlapScan A
+ * Group 3: Join(Group 1, Group 2)
+ * Group 4: Join(Group 0, Group 3)
+ * <p>
+ * Then:
+ * Copy In Join(Group 2, Group 1) into Group 3
+ * <p>
+ * Expected:
+ * Group 0: LogicalOlapScan C
+ * Group 1: LogicalOlapScan B
+ * Group 2: LogicalOlapScan A
+ * Group 3: Join(Group 1, Group 2), Join(Group 2, Group 1)
+ * Group 4: Join(Group 0, Group 3)
+ */
+ @Test
+ public void testInsertSameGroup() {
+ PlanChecker.from(MemoTestUtils.createConnectContext(), logicalJoinABC)
+ .transform(
+ // swap join's children
+ logicalJoin(logicalOlapScan(), logicalOlapScan()).then(joinBA ->
+ new LogicalProject<>(Lists.newArrayList(joinBA.getOutput()),
+ new LogicalJoin<>(JoinType.INNER_JOIN, joinBA.right(), joinBA.left()))
+ ))
+ .checkGroupNum(6)
+ .checkGroupExpressionNum(7)
+ .checkMemo(memo -> {
+ Group root = memo.getRoot();
+ Assertions.assertEquals(1, root.getLogicalExpressions().size());
+ GroupExpression joinABC = root.getLogicalExpression();
+ Assertions.assertEquals(2, joinABC.child(0).getLogicalExpressions().size());
+ Assertions.assertEquals(1, joinABC.child(1).getLogicalExpressions().size());
+ GroupExpression joinAB = joinABC.child(0).getLogicalExpressions().get(0);
+ GroupExpression project = joinABC.child(0).getLogicalExpressions().get(1);
+ GroupExpression joinBA = project.child(0).getLogicalExpression();
+ Assertions.assertTrue(joinAB.getPlan() instanceof LogicalJoin);
+ Assertions.assertTrue(joinBA.getPlan() instanceof LogicalJoin);
+ });
+
+ }
+
+ @Test
+ public void initByOneLevelPlan() {
+ OlapTable table = PlanConstructor.newOlapTable(0, "a", 1);
+ LogicalOlapScan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), table);
+
+ PlanChecker.from(connectContext, scan)
+ .checkGroupNum(1)
+ .matches(
+ logicalOlapScan().when(scan::equals)
+ );
+ }
+
+ @Test
+ public void initByTwoLevelChainPlan() {
+ OlapTable table = PlanConstructor.newOlapTable(0, "a", 1);
+ LogicalOlapScan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), table);
+
+ LogicalProject<LogicalOlapScan> topProject = new LogicalProject<>(
+ ImmutableList.of(scan.computeOutput().get(0)), scan);
+
+ PlanChecker.from(connectContext, topProject)
+ .checkGroupNum(2)
+ .matches(
+ logicalProject(
+ any().when(child -> Objects.equals(child, scan))
+ ).when(root -> Objects.equals(root, topProject))
+ );
+ }
+
+ @Test
+ public void initByJoinSameUnboundTable() {
+ UnboundRelation scanA = new UnboundRelation(ImmutableList.of("a"));
+
+ LogicalJoin<UnboundRelation, UnboundRelation> topJoin = new LogicalJoin<>(JoinType.INNER_JOIN, scanA, scanA);
+
+ PlanChecker.from(connectContext, topJoin)
+ .checkGroupNum(3)
+ .matches(
+ logicalJoin(
+ any().when(left -> Objects.equals(left, scanA)),
+ any().when(right -> Objects.equals(right, scanA))
+ ).when(root -> Objects.equals(root, topJoin))
+ );
+ }
+
+ @Test
+ public void initByJoinSameLogicalTable() {
+ IdGenerator<RelationId> generator = RelationId.createGenerator();
+ OlapTable tableA = PlanConstructor.newOlapTable(0, "a", 1);
+ LogicalOlapScan scanA = new LogicalOlapScan(generator.getNextId(), tableA);
+ LogicalOlapScan scanA1 = new LogicalOlapScan(generator.getNextId(), tableA);
+
+ LogicalJoin<LogicalOlapScan, LogicalOlapScan> topJoin = new LogicalJoin<>(JoinType.INNER_JOIN, scanA, scanA1);
+
+ PlanChecker.from(connectContext, topJoin)
+ .checkGroupNum(3)
+ .matches(
+ logicalJoin(
+ any().when(left -> Objects.equals(left, scanA)),
+ any().when(right -> Objects.equals(right, scanA1))
+ ).when(root -> Objects.equals(root, topJoin))
+ );
+ }
+
+ @Test
+ public void initByTwoLevelJoinPlan() {
+ IdGenerator<RelationId> generator = RelationId.createGenerator();
+ OlapTable tableA = PlanConstructor.newOlapTable(0, "a", 1);
+ OlapTable tableB = PlanConstructor.newOlapTable(0, "b", 1);
+ LogicalOlapScan scanA = new LogicalOlapScan(generator.getNextId(), tableA);
+ LogicalOlapScan scanB = new LogicalOlapScan(generator.getNextId(), tableB);
+
+ LogicalJoin<LogicalOlapScan, LogicalOlapScan> topJoin = new LogicalJoin<>(JoinType.INNER_JOIN, scanA, scanB);
+
+ PlanChecker.from(connectContext, topJoin)
+ .checkGroupNum(3)
+ .matches(
+ logicalJoin(
+ any().when(left -> Objects.equals(left, scanA)),
+ any().when(right -> Objects.equals(right, scanB))
+ ).when(root -> Objects.equals(root, topJoin))
+ );
+ }
+
+ @Test
+ public void initByThreeLevelChainPlan() {
+ OlapTable table = PlanConstructor.newOlapTable(0, "a", 1);
+ LogicalOlapScan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), table);
+
+ LogicalProject<LogicalOlapScan> project = new LogicalProject<>(
+ ImmutableList.of(scan.computeOutput().get(0)), scan);
+ LogicalFilter<LogicalProject<LogicalOlapScan>> filter = new LogicalFilter<>(
+ new EqualTo(scan.computeOutput().get(0), new IntegerLiteral(1)), project);
+
+ PlanChecker.from(connectContext, filter)
+ .checkGroupNum(3)
+ .matches(
+ logicalFilter(
+ logicalProject(
+ any().when(child -> Objects.equals(child, scan))
+ ).when(root -> Objects.equals(root, project))
+ ).when(root -> Objects.equals(root, filter))
+ );
+ }
+
+ @Test
+ public void initByThreeLevelBushyPlan() {
+ IdGenerator<RelationId> generator = RelationId.createGenerator();
+ OlapTable tableA = PlanConstructor.newOlapTable(0, "a", 1);
+ OlapTable tableB = PlanConstructor.newOlapTable(0, "b", 1);
+ OlapTable tableC = PlanConstructor.newOlapTable(0, "c", 1);
+ OlapTable tableD = PlanConstructor.newOlapTable(0, "d", 1);
+ LogicalOlapScan scanA = new LogicalOlapScan(generator.getNextId(), tableA);
+ LogicalOlapScan scanB = new LogicalOlapScan(generator.getNextId(), tableB);
+ LogicalOlapScan scanC = new LogicalOlapScan(generator.getNextId(), tableC);
+ LogicalOlapScan scanD = new LogicalOlapScan(generator.getNextId(), tableD);
+
+ LogicalJoin<LogicalOlapScan, LogicalOlapScan> leftJoin = new LogicalJoin<>(JoinType.CROSS_JOIN, scanA, scanB);
+ LogicalJoin<LogicalOlapScan, LogicalOlapScan> rightJoin = new LogicalJoin<>(JoinType.CROSS_JOIN, scanC, scanD);
+ LogicalJoin topJoin = new LogicalJoin<>(JoinType.CROSS_JOIN, leftJoin, rightJoin);
+
+ PlanChecker.from(connectContext, topJoin)
+ .checkGroupNum(7)
+ .matches(
+ logicalJoin(
+ logicalJoin(
+ any().when(child -> Objects.equals(child, scanA)),
+ any().when(child -> Objects.equals(child, scanB))
+ ).when(left -> Objects.equals(left, leftJoin)),
+
+ logicalJoin(
+ any().when(child -> Objects.equals(child, scanC)),
+ any().when(child -> Objects.equals(child, scanD))
+ ).when(right -> Objects.equals(right, rightJoin))
+ ).when(root -> Objects.equals(root, topJoin))
+ );
+ }
+
/*
* A -> A:
*
@@ -204,25 +433,15 @@ public class MemoRewriteTest implements PatternMatchSupported {
A a2 = new A(ImmutableList.of("student"), State.ALREADY_REWRITE);
LogicalLimit<UnboundRelation> limit = new LogicalLimit<>(1, 0, a2);
- PlanChecker.from(connectContext, a)
- .applyBottomUp(
- unboundRelation()
- // 4: add state condition to the pattern's predicates
- .when(r -> (r instanceof A) && ((A) r).state == State.NOT_REWRITE)
- .then(unboundRelation -> {
- // 5: new plan and change state, so this case equal to 'A -> B(C)', which C has
- // different state with A
- A notRewritePlan = (A) unboundRelation;
- return limit.withChildren(notRewritePlan.withState(State.ALREADY_REWRITE));
- }
- )
- )
- .checkGroupNum(2)
- .matchesFromRoot(
- logicalLimit(
- unboundRelation().when(a2::equals)
- ).when(limit::equals)
- );
+ PlanChecker.from(connectContext, a).applyBottomUp(unboundRelation()
+ // 4: add state condition to the pattern's predicates
+ .when(r -> (r instanceof A) && ((A) r).state == State.NOT_REWRITE).then(unboundRelation -> {
+ // 5: new plan and change state, so this case equal to 'A -> B(C)', which C has
+ // different state with A
+ A notRewritePlan = (A) unboundRelation;
+ return limit.withChildren(notRewritePlan.withState(State.ALREADY_REWRITE));
+ })).checkGroupNum(2)
+ .matchesFromRoot(logicalLimit(unboundRelation().when(a2::equals)).when(limit::equals));
}
/*
@@ -359,7 +578,7 @@ public class MemoRewriteTest implements PatternMatchSupported {
)
.checkGroupNum(1)
.matchesFromRoot(
- logicalOlapScan().when(student::equals)
+ logicalOlapScan().when(student::equals)
);
}
@@ -801,30 +1020,30 @@ public class MemoRewriteTest implements PatternMatchSupported {
)
))
.applyTopDown(
- logicalLimit(logicalJoin()).then(limit -> {
- LogicalJoin<GroupPlan, GroupPlan> join = limit.child();
- switch (join.getJoinType()) {
- case LEFT_OUTER_JOIN:
- return join.withChildren(limit.withChildren(join.left()), join.right());
- case RIGHT_OUTER_JOIN:
- return join.withChildren(join.left(), limit.withChildren(join.right()));
- case CROSS_JOIN:
- return join.withChildren(limit.withChildren(join.left()), limit.withChildren(join.right()));
- case INNER_JOIN:
- if (!join.getHashJoinConjuncts().isEmpty()) {
- return join.withChildren(
- limit.withChildren(join.left()),
- limit.withChildren(join.right())
- );
- } else {
+ logicalLimit(logicalJoin()).then(limit -> {
+ LogicalJoin<GroupPlan, GroupPlan> join = limit.child();
+ switch (join.getJoinType()) {
+ case LEFT_OUTER_JOIN:
+ return join.withChildren(limit.withChildren(join.left()), join.right());
+ case RIGHT_OUTER_JOIN:
+ return join.withChildren(join.left(), limit.withChildren(join.right()));
+ case CROSS_JOIN:
+ return join.withChildren(limit.withChildren(join.left()), limit.withChildren(join.right()));
+ case INNER_JOIN:
+ if (!join.getHashJoinConjuncts().isEmpty()) {
+ return join.withChildren(
+ limit.withChildren(join.left()),
+ limit.withChildren(join.right())
+ );
+ } else {
+ return limit;
+ }
+ case LEFT_ANTI_JOIN:
+ // todo: support anti join.
+ default:
return limit;
- }
- case LEFT_ANTI_JOIN:
- // todo: support anti join.
- default:
- return limit;
- }
- })
+ }
+ })
)
.matchesFromRoot(
logicalJoin(
diff --git a/regression-test/suites/tpch_sf1_p1/tpch_sf1/nereids/q21.groovy b/regression-test/suites/tpch_sf1_p1/tpch_sf1/nereids/q21.groovy
new file mode 100644
index 0000000000..0407830d15
--- /dev/null
+++ b/regression-test/suites/tpch_sf1_p1/tpch_sf1/nereids/q21.groovy
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+suite("tpch_sf1_q21_nereids") { // runs two forms of tpch_sf1 query q21 with the Nereids planner switched on
+ String realDb = context.config.getDbNameByFile(context.file)
+ // get parent directory's group: cut the file-derived db name at its last "_"
+ realDb = realDb.substring(0, realDb.lastIndexOf("_"))
+
+ sql "use ${realDb}"
+
+ sql 'set enable_nereids_planner=true' // plan the queries below with Nereids
+ sql 'set enable_fallback_to_original_planner=false' // presumably makes a Nereids failure surface as an error - confirm
+ // first form: comma joins plus correlated EXISTS / NOT EXISTS subqueries on lineitem
+ qt_select """
+select
+ s_name,
+ count(*) as numwait
+from
+ supplier,
+ lineitem l1,
+ orders,
+ nation
+where
+ s_suppkey = l1.l_suppkey
+ and o_orderkey = l1.l_orderkey
+ and o_orderstatus = 'F'
+ and l1.l_receiptdate > l1.l_commitdate
+ and exists (
+ select
+ *
+ from
+ lineitem l2
+ where
+ l2.l_orderkey = l1.l_orderkey
+ and l2.l_suppkey <> l1.l_suppkey
+ )
+ and not exists (
+ select
+ *
+ from
+ lineitem l3
+ where
+ l3.l_orderkey = l1.l_orderkey
+ and l3.l_suppkey <> l1.l_suppkey
+ and l3.l_receiptdate > l3.l_commitdate
+ )
+ and s_nationkey = n_nationkey
+ and n_name = 'SAUDI ARABIA'
+group by
+ s_name
+order by
+ numwait desc,
+ s_name
+limit 100;
+ """
+ // second form: explicit joins with SET_VAR hints; right semi / right anti joins appear to replace the subqueries above - confirm
+ qt_select """
+select /*+SET_VAR(exec_mem_limit=8589934592, parallel_fragment_exec_instance_num=16, enable_vectorized_engine=true, batch_size=4096, disable_join_reorder=true, enable_cost_based_join_reorder=true, enable_projection=true) */
+s_name, count(*) as numwait
+from orders join
+(
+ select * from
+ lineitem l2 right semi join
+ (
+ select * from
+ lineitem l3 right anti join
+ (
+ select * from
+ lineitem l1 join
+ (
+ select * from
+ supplier join nation
+ where s_nationkey = n_nationkey
+ and n_name = 'SAUDI ARABIA'
+ ) t1
+ where t1.s_suppkey = l1.l_suppkey and l1.l_receiptdate > l1.l_commitdate
+ ) t2
+ on l3.l_orderkey = t2.l_orderkey and l3.l_suppkey <> t2.l_suppkey and l3.l_receiptdate > l3.l_commitdate
+ ) t3
+ on l2.l_orderkey = t3.l_orderkey and l2.l_suppkey <> t3.l_suppkey
+) t4
+on o_orderkey = t4.l_orderkey and o_orderstatus = 'F'
+group by
+ t4.s_name
+order by
+ numwait desc,
+ t4.s_name
+limit 100;
+ """
+
+}
\ No newline at end of file
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org