You are viewing a plain text version of this content. The canonical link to it (a hyperlink in the original page) was not preserved in this plain-text version.
Posted to commits@doris.apache.org by mo...@apache.org on 2022/10/20 14:23:42 UTC

[doris] branch master updated: [fix](Nereids) NPE caused by GroupExpression has null owner group when choosing best plan (#13252)

This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 4ae777bfc5 [fix](Nereids) NPE caused by GroupExpression has null owner group when choosing best plan (#13252)
4ae777bfc5 is described below

commit 4ae777bfc55a8b0fa1721011a088f8d9392d9327
Author: Kikyou1997 <33...@users.noreply.github.com>
AuthorDate: Thu Oct 20 22:23:36 2022 +0800

    [fix](Nereids) NPE caused by GroupExpression has null owner group when choosing best plan (#13252)
---
 .../org/apache/doris/nereids/NereidsPlanner.java   |  42 +--
 .../nereids/jobs/cascades/DeriveStatsJob.java      |   5 +-
 .../java/org/apache/doris/nereids/memo/Group.java  |   3 +-
 .../apache/doris/nereids/memo/GroupExpression.java |  11 +-
 .../java/org/apache/doris/nereids/memo/Memo.java   |  13 +-
 .../apache/doris/nereids/trees/plans/FakePlan.java | 104 +++++++
 .../apache/doris/nereids/memo/MemoCopyInTest.java  |  84 ------
 .../apache/doris/nereids/memo/MemoInitTest.java    | 182 ------------
 .../memo/{MemoRewriteTest.java => MemoTest.java}   | 307 ++++++++++++++++++---
 .../suites/tpch_sf1_p1/tpch_sf1/nereids/q21.groovy | 107 +++++++
 10 files changed, 520 insertions(+), 338 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java
index ff1ec80890..59cb499900 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java
@@ -190,25 +190,31 @@ public class NereidsPlanner extends Planner {
 
     private PhysicalPlan chooseBestPlan(Group rootGroup, PhysicalProperties physicalProperties)
             throws AnalysisException {
-        GroupExpression groupExpression = rootGroup.getLowestCostPlan(physicalProperties).orElseThrow(
-                () -> new AnalysisException("lowestCostPlans with physicalProperties doesn't exist")).second;
-        List<PhysicalProperties> inputPropertiesList = groupExpression.getInputPropertiesList(physicalProperties);
-
-        List<Plan> planChildren = Lists.newArrayList();
-        for (int i = 0; i < groupExpression.arity(); i++) {
-            planChildren.add(chooseBestPlan(groupExpression.child(i), inputPropertiesList.get(i)));
-        }
-
-        Plan plan = groupExpression.getPlan().withChildren(planChildren);
-        if (!(plan instanceof PhysicalPlan)) {
-            throw new AnalysisException("Result plan must be PhysicalPlan");
+        try {
+            GroupExpression groupExpression = rootGroup.getLowestCostPlan(physicalProperties).orElseThrow(
+                    () -> new AnalysisException("lowestCostPlans with physicalProperties doesn't exist")).second;
+            List<PhysicalProperties> inputPropertiesList = groupExpression.getInputPropertiesList(physicalProperties);
+
+            List<Plan> planChildren = Lists.newArrayList();
+            for (int i = 0; i < groupExpression.arity(); i++) {
+                planChildren.add(chooseBestPlan(groupExpression.child(i), inputPropertiesList.get(i)));
+            }
+
+            Plan plan = groupExpression.getPlan().withChildren(planChildren);
+            if (!(plan instanceof PhysicalPlan)) {
+                throw new AnalysisException("Result plan must be PhysicalPlan");
+            }
+
+            // TODO: set (logical and physical)properties/statistics/... for physicalPlan.
+            PhysicalPlan physicalPlan = ((PhysicalPlan) plan).withPhysicalPropertiesAndStats(
+                    groupExpression.getOutputProperties(physicalProperties),
+                    groupExpression.getOwnerGroup().getStatistics());
+            return physicalPlan;
+        } catch (Exception e) {
+            String memo = cascadesContext.getMemo().toString();
+            LOG.warn("Failed to choose best plan, memo structure:{}", memo, e);
+            throw new AnalysisException("Failed to choose best plan", e);
         }
-
-        // TODO: set (logical and physical)properties/statistics/... for physicalPlan.
-        PhysicalPlan physicalPlan = ((PhysicalPlan) plan).withPhysicalPropertiesAndStats(
-                groupExpression.getOutputProperties(physicalProperties),
-                groupExpression.getOwnerGroup().getStatistics());
-        return physicalPlan;
     }
 
     @Override
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJob.java
index 618382438a..8797327311 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJob.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJob.java
@@ -60,9 +60,8 @@ public class DeriveStatsJob extends Job {
             deriveChildren = true;
             pushJob(new DeriveStatsJob(this));
             for (Group child : groupExpression.children()) {
-                GroupExpression childGroupExpr = child.getLogicalExpressions().get(0);
-                if (!child.getLogicalExpressions().isEmpty() && !childGroupExpr.isStatDerived()) {
-                    pushJob(new DeriveStatsJob(childGroupExpr, context));
+                if (!child.getLogicalExpressions().isEmpty()) {
+                    pushJob(new DeriveStatsJob(child.getLogicalExpressions().get(0), context));
                 }
             }
         } else {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java
index 6a65f4a373..fbb0a0d32b 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java
@@ -429,9 +429,8 @@ public class Group {
         lowestCostPlans.forEach((physicalProperties, costAndGroupExpr) -> {
             GroupExpression bestGroupExpression = costAndGroupExpr.second;
             // change into target group.
-            if (bestGroupExpression.getOwnerGroup() == this) {
+            if (bestGroupExpression.getOwnerGroup() == this || bestGroupExpression.getOwnerGroup() == null) {
                 bestGroupExpression.setOwnerGroup(target);
-                bestGroupExpression.children().set(0, target);
             }
             if (!target.lowestCostPlans.containsKey(physicalProperties)) {
                 target.lowestCostPlans.put(physicalProperties, costAndGroupExpr);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java
index a6f555b649..e8de9a6e54 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java
@@ -229,11 +229,20 @@ public class GroupExpression {
     public String toString() {
         StringBuilder builder = new StringBuilder();
         builder.append(ownerGroup.getGroupId()).append("(plan=").append(plan).append(") children=[");
+        if (ownerGroup == null) {
+            builder.append("OWNER GROUP IS NULL[]");
+        } else {
+            builder.append(ownerGroup.getGroupId()).append("(plan=").append(plan.toString()).append(") children=[");
+        }
         for (Group group : children) {
             builder.append(group.getGroupId()).append(" ");
         }
         builder.append("] stats=");
-        builder.append(ownerGroup.getStatistics());
+        if (ownerGroup != null) {
+            builder.append(ownerGroup.getStatistics());
+        } else {
+            builder.append("NULL");
+        }
         return builder.toString();
     }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
index a484584a49..abafce3880 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
@@ -47,6 +47,11 @@ public class Memo {
     private final Map<GroupExpression, GroupExpression> groupExpressions = Maps.newHashMap();
     private final Group root;
 
+    // FOR TEST ONLY
+    public Memo() {
+        root = null;
+    }
+
     public Memo(Plan plan) {
         root = init(plan);
     }
@@ -325,20 +330,20 @@ public class Memo {
      * @param destination destination group
      * @return merged group
      */
-    private Group mergeGroup(Group source, Group destination) {
+    public Group mergeGroup(Group source, Group destination) {
         if (source.equals(destination)) {
             return source;
         }
         List<GroupExpression> needReplaceChild = Lists.newArrayList();
-        groupExpressions.values().forEach(groupExpression -> {
+        for (GroupExpression groupExpression : groupExpressions.values()) {
             if (groupExpression.children().contains(source)) {
                 if (groupExpression.getOwnerGroup().equals(destination)) {
                     // cycle, we should not merge
-                    return;
+                    return null;
                 }
                 needReplaceChild.add(groupExpression);
             }
-        });
+        }
         for (GroupExpression groupExpression : needReplaceChild) {
             groupExpressions.remove(groupExpression);
             List<Group> children = groupExpression.children();
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/FakePlan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/FakePlan.java
new file mode 100644
index 0000000000..2058ad37d2
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/FakePlan.java
@@ -0,0 +1,104 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.plans;
+
+import org.apache.doris.nereids.memo.GroupExpression;
+import org.apache.doris.nereids.properties.LogicalProperties;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * Used for unit test only.
+ */
+public class FakePlan implements Plan {
+
+    @Override
+    public List<Plan> children() {
+        return null;
+    }
+
+    @Override
+    public Plan child(int index) {
+        return null;
+    }
+
+    @Override
+    public int arity() {
+        return 0;
+    }
+
+    @Override
+    public Plan withChildren(List<Plan> children) {
+        return null;
+    }
+
+    @Override
+    public PlanType getType() {
+        return null;
+    }
+
+    @Override
+    public Optional<GroupExpression> getGroupExpression() {
+        return Optional.empty();
+    }
+
+    @Override
+    public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
+        return null;
+    }
+
+    @Override
+    public List<? extends Expression> getExpressions() {
+        return new ArrayList<>();
+    }
+
+    @Override
+    public LogicalProperties getLogicalProperties() {
+        return new LogicalProperties(ArrayList::new);
+    }
+
+    @Override
+    public boolean canBind() {
+        return false;
+    }
+
+    @Override
+    public List<Slot> getOutput() {
+        return new ArrayList<>();
+    }
+
+    @Override
+    public String treeString() {
+        return "DUMMY";
+    }
+
+    @Override
+    public Plan withGroupExpression(Optional<GroupExpression> groupExpression) {
+        return this;
+    }
+
+    @Override
+    public Plan withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
+        return this;
+    }
+}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoCopyInTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoCopyInTest.java
deleted file mode 100644
index 5cf28d4528..0000000000
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoCopyInTest.java
+++ /dev/null
@@ -1,84 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.nereids.memo;
-
-import org.apache.doris.nereids.trees.plans.JoinType;
-import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
-import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
-import org.apache.doris.nereids.util.MemoTestUtils;
-import org.apache.doris.nereids.util.PatternMatchSupported;
-import org.apache.doris.nereids.util.PlanChecker;
-import org.apache.doris.nereids.util.PlanConstructor;
-
-import com.google.common.collect.Lists;
-import org.junit.jupiter.api.Assertions;
-import org.junit.jupiter.api.Test;
-
-public class MemoCopyInTest implements PatternMatchSupported {
-    LogicalJoin<LogicalOlapScan, LogicalOlapScan> logicalJoinAB = new LogicalJoin<>(JoinType.INNER_JOIN,
-            PlanConstructor.newLogicalOlapScan(0, "A", 0), PlanConstructor.newLogicalOlapScan(1, "B", 0));
-    LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, LogicalOlapScan> logicalJoinABC = new LogicalJoin<>(
-            JoinType.INNER_JOIN, logicalJoinAB, PlanConstructor.newLogicalOlapScan(2, "C", 0));
-
-    /**
-     * Original:
-     * Group 0: LogicalOlapScan C
-     * Group 1: LogicalOlapScan B
-     * Group 2: LogicalOlapScan A
-     * Group 3: Join(Group 1, Group 2)
-     * Group 4: Join(Group 0, Group 3)
-     * <p>
-     * Then:
-     * Copy In Join(Group 2, Group 1) into Group 3
-     * <p>
-     * Expected:
-     * Group 0: LogicalOlapScan C
-     * Group 1: LogicalOlapScan B
-     * Group 2: LogicalOlapScan A
-     * Group 3: Join(Group 1, Group 2), Join(Group 2, Group 1)
-     * Group 4: Join(Group 0, Group 3)
-     */
-    @Test
-    public void testInsertSameGroup() {
-        PlanChecker.from(MemoTestUtils.createConnectContext(), logicalJoinABC)
-                .transform(
-                        // swap join's children
-                        logicalJoin(logicalOlapScan(), logicalOlapScan()).then(joinBA ->
-                                new LogicalProject<>(Lists.newArrayList(joinBA.getOutput()),
-                                        new LogicalJoin<>(JoinType.INNER_JOIN, joinBA.right(), joinBA.left()))
-                        ))
-                .checkGroupNum(6)
-                .checkGroupExpressionNum(7)
-                .checkMemo(memo -> {
-                    Group root = memo.getRoot();
-                    Assertions.assertEquals(1, root.getLogicalExpressions().size());
-                    GroupExpression joinABC = root.getLogicalExpression();
-                    Assertions.assertEquals(2, joinABC.child(0).getLogicalExpressions().size());
-                    Assertions.assertEquals(1, joinABC.child(1).getLogicalExpressions().size());
-                    GroupExpression joinAB = joinABC.child(0).getLogicalExpressions().get(0);
-                    GroupExpression project = joinABC.child(0).getLogicalExpressions().get(1);
-                    GroupExpression joinBA = project.child(0).getLogicalExpression();
-                    Assertions.assertTrue(joinAB.getPlan() instanceof LogicalJoin);
-                    Assertions.assertTrue(joinBA.getPlan() instanceof LogicalJoin);
-                });
-
-    }
-
-    // TODO: test mergeGroup().
-}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoInitTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoInitTest.java
deleted file mode 100644
index 81f13afc80..0000000000
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoInitTest.java
+++ /dev/null
@@ -1,182 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.nereids.memo;
-
-import org.apache.doris.catalog.OlapTable;
-import org.apache.doris.common.IdGenerator;
-import org.apache.doris.nereids.analyzer.UnboundRelation;
-import org.apache.doris.nereids.trees.expressions.EqualTo;
-import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
-import org.apache.doris.nereids.trees.plans.JoinType;
-import org.apache.doris.nereids.trees.plans.RelationId;
-import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
-import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
-import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
-import org.apache.doris.nereids.util.MemoTestUtils;
-import org.apache.doris.nereids.util.PatternMatchSupported;
-import org.apache.doris.nereids.util.PlanChecker;
-import org.apache.doris.nereids.util.PlanConstructor;
-import org.apache.doris.qe.ConnectContext;
-
-import com.google.common.collect.ImmutableList;
-import org.junit.jupiter.api.Test;
-
-import java.util.Objects;
-
-public class MemoInitTest implements PatternMatchSupported {
-    private ConnectContext connectContext = MemoTestUtils.createConnectContext();
-
-    @Test
-    public void initByOneLevelPlan() {
-        OlapTable table = PlanConstructor.newOlapTable(0, "a", 1);
-        LogicalOlapScan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), table);
-
-        PlanChecker.from(connectContext, scan)
-                .checkGroupNum(1)
-                .matches(
-                    logicalOlapScan().when(scan::equals)
-                );
-    }
-
-    @Test
-    public void initByTwoLevelChainPlan() {
-        OlapTable table = PlanConstructor.newOlapTable(0, "a", 1);
-        LogicalOlapScan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), table);
-
-        LogicalProject<LogicalOlapScan> topProject = new LogicalProject<>(
-                ImmutableList.of(scan.computeOutput().get(0)), scan);
-
-        PlanChecker.from(connectContext, topProject)
-                .checkGroupNum(2)
-                .matches(
-                        logicalProject(
-                                any().when(child -> Objects.equals(child, scan))
-                        ).when(root -> Objects.equals(root, topProject))
-                );
-    }
-
-    @Test
-    public void initByJoinSameUnboundTable() {
-        UnboundRelation scanA = new UnboundRelation(ImmutableList.of("a"));
-
-        LogicalJoin<UnboundRelation, UnboundRelation> topJoin = new LogicalJoin<>(JoinType.INNER_JOIN, scanA, scanA);
-
-        PlanChecker.from(connectContext, topJoin)
-                .checkGroupNum(3)
-                .matches(
-                        logicalJoin(
-                                any().when(left -> Objects.equals(left, scanA)),
-                                any().when(right -> Objects.equals(right, scanA))
-                        ).when(root -> Objects.equals(root, topJoin))
-                );
-    }
-
-    @Test
-    public void initByJoinSameLogicalTable() {
-        IdGenerator<RelationId> generator = RelationId.createGenerator();
-        OlapTable tableA = PlanConstructor.newOlapTable(0, "a", 1);
-        LogicalOlapScan scanA = new LogicalOlapScan(generator.getNextId(), tableA);
-        LogicalOlapScan scanA1 = new LogicalOlapScan(generator.getNextId(), tableA);
-
-        LogicalJoin<LogicalOlapScan, LogicalOlapScan> topJoin = new LogicalJoin<>(JoinType.INNER_JOIN, scanA, scanA1);
-
-        PlanChecker.from(connectContext, topJoin)
-                .checkGroupNum(3)
-                .matches(
-                        logicalJoin(
-                                any().when(left -> Objects.equals(left, scanA)),
-                                any().when(right -> Objects.equals(right, scanA1))
-                        ).when(root -> Objects.equals(root, topJoin))
-                );
-    }
-
-    @Test
-    public void initByTwoLevelJoinPlan() {
-        IdGenerator<RelationId> generator = RelationId.createGenerator();
-        OlapTable tableA = PlanConstructor.newOlapTable(0, "a", 1);
-        OlapTable tableB = PlanConstructor.newOlapTable(0, "b", 1);
-        LogicalOlapScan scanA = new LogicalOlapScan(generator.getNextId(), tableA);
-        LogicalOlapScan scanB = new LogicalOlapScan(generator.getNextId(), tableB);
-
-        LogicalJoin<LogicalOlapScan, LogicalOlapScan> topJoin = new LogicalJoin<>(JoinType.INNER_JOIN, scanA, scanB);
-
-        PlanChecker.from(connectContext, topJoin)
-                .checkGroupNum(3)
-                .matches(
-                        logicalJoin(
-                                any().when(left -> Objects.equals(left, scanA)),
-                                any().when(right -> Objects.equals(right, scanB))
-                        ).when(root -> Objects.equals(root, topJoin))
-                );
-    }
-
-    @Test
-    public void initByThreeLevelChainPlan() {
-        OlapTable table = PlanConstructor.newOlapTable(0, "a", 1);
-        LogicalOlapScan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), table);
-
-        LogicalProject<LogicalOlapScan> project = new LogicalProject<>(
-                ImmutableList.of(scan.computeOutput().get(0)), scan);
-        LogicalFilter<LogicalProject<LogicalOlapScan>> filter = new LogicalFilter<>(
-                new EqualTo(scan.computeOutput().get(0), new IntegerLiteral(1)), project);
-
-        PlanChecker.from(connectContext, filter)
-                .checkGroupNum(3)
-                .matches(
-                        logicalFilter(
-                            logicalProject(
-                                    any().when(child -> Objects.equals(child, scan))
-                            ).when(root -> Objects.equals(root, project))
-                        ).when(root -> Objects.equals(root, filter))
-                );
-    }
-
-    @Test
-    public void initByThreeLevelBushyPlan() {
-        IdGenerator<RelationId> generator = RelationId.createGenerator();
-        OlapTable tableA = PlanConstructor.newOlapTable(0, "a", 1);
-        OlapTable tableB = PlanConstructor.newOlapTable(0, "b", 1);
-        OlapTable tableC = PlanConstructor.newOlapTable(0, "c", 1);
-        OlapTable tableD = PlanConstructor.newOlapTable(0, "d", 1);
-        LogicalOlapScan scanA = new LogicalOlapScan(generator.getNextId(), tableA);
-        LogicalOlapScan scanB = new LogicalOlapScan(generator.getNextId(), tableB);
-        LogicalOlapScan scanC = new LogicalOlapScan(generator.getNextId(), tableC);
-        LogicalOlapScan scanD = new LogicalOlapScan(generator.getNextId(), tableD);
-
-        LogicalJoin<LogicalOlapScan, LogicalOlapScan> leftJoin = new LogicalJoin<>(JoinType.CROSS_JOIN, scanA, scanB);
-        LogicalJoin<LogicalOlapScan, LogicalOlapScan> rightJoin = new LogicalJoin<>(JoinType.CROSS_JOIN, scanC, scanD);
-        LogicalJoin topJoin = new LogicalJoin<>(JoinType.CROSS_JOIN, leftJoin, rightJoin);
-
-        PlanChecker.from(connectContext, topJoin)
-                .checkGroupNum(7)
-                .matches(
-                        logicalJoin(
-                                logicalJoin(
-                                        any().when(child -> Objects.equals(child, scanA)),
-                                        any().when(child -> Objects.equals(child, scanB))
-                                ).when(left -> Objects.equals(left, leftJoin)),
-
-                                logicalJoin(
-                                        any().when(child -> Objects.equals(child, scanC)),
-                                        any().when(child -> Objects.equals(child, scanD))
-                                ).when(right -> Objects.equals(right, rightJoin))
-                        ).when(root -> Objects.equals(root, topJoin))
-                );
-    }
-}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoRewriteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java
similarity index 74%
rename from fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoRewriteTest.java
rename to fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java
index 7a82c5dcd2..9ad9bf8d16 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoRewriteTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java
@@ -17,17 +17,24 @@
 
 package org.apache.doris.nereids.memo;
 
+import org.apache.doris.catalog.OlapTable;
+import org.apache.doris.common.IdGenerator;
+import org.apache.doris.common.jmockit.Deencapsulation;
 import org.apache.doris.nereids.analyzer.UnboundRelation;
 import org.apache.doris.nereids.analyzer.UnboundSlot;
 import org.apache.doris.nereids.properties.LogicalProperties;
+import org.apache.doris.nereids.properties.PhysicalProperties;
 import org.apache.doris.nereids.properties.UnboundLogicalProperties;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
+import org.apache.doris.nereids.trees.plans.FakePlan;
 import org.apache.doris.nereids.trees.plans.GroupPlan;
 import org.apache.doris.nereids.trees.plans.JoinType;
 import org.apache.doris.nereids.trees.plans.LeafPlan;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.RelationId;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
 import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
 import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
@@ -45,13 +52,235 @@ import com.google.common.collect.Lists;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
+import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;
 
-public class MemoRewriteTest implements PatternMatchSupported {
+class MemoTest implements PatternMatchSupported {
+
     private ConnectContext connectContext = MemoTestUtils.createConnectContext();
 
+    private LogicalJoin<LogicalOlapScan, LogicalOlapScan> logicalJoinAB = new LogicalJoin<>(JoinType.INNER_JOIN,
+            PlanConstructor.newLogicalOlapScan(0, "A", 0),
+            PlanConstructor.newLogicalOlapScan(1, "B", 0));
+
+    private LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, LogicalOlapScan> logicalJoinABC = new LogicalJoin<>(
+            JoinType.INNER_JOIN, logicalJoinAB, PlanConstructor.newLogicalOlapScan(2, "C", 0));
+
+    @Test
+    void mergeGroup() throws Exception {
+        Memo memo = new Memo();
+        GroupId gid2 = new GroupId(2);
+        Group srcGroup = new Group(gid2, new GroupExpression(new FakePlan()), new LogicalProperties(ArrayList::new));
+        GroupId gid3 = new GroupId(3);
+        Group dstGroup = new Group(gid3, new GroupExpression(new FakePlan()), new LogicalProperties(ArrayList::new));
+        FakePlan d = new FakePlan();
+        GroupExpression ge1 = new GroupExpression(d, Arrays.asList(srcGroup));
+        GroupId gid0 = new GroupId(0);
+        Group g1 = new Group(gid0, ge1, new LogicalProperties(ArrayList::new));
+        g1.setBestPlan(ge1, Double.MIN_VALUE, PhysicalProperties.ANY);
+        GroupExpression ge2 = new GroupExpression(d, Arrays.asList(dstGroup));
+        GroupId gid1 = new GroupId(1);
+        Group g2 = new Group(gid1, ge2, new LogicalProperties(ArrayList::new));
+        Map<GroupId, Group> groups = (Map<GroupId, Group>) Deencapsulation.getField(memo, "groups");
+        groups.put(gid2, srcGroup);
+        groups.put(gid3, dstGroup);
+        groups.put(gid0, g1);
+        groups.put(gid1, g2);
+        Map<GroupExpression, GroupExpression> groupExpressions =
+                (Map<GroupExpression, GroupExpression>) Deencapsulation.getField(memo, "groupExpressions");
+        groupExpressions.put(ge1, ge1);
+        groupExpressions.put(ge2, ge2);
+        memo.mergeGroup(srcGroup, dstGroup);
+        Assertions.assertNull(g1.getBestPlan(PhysicalProperties.ANY));
+        Assertions.assertEquals(ge1.getOwnerGroup(), g2);
+    }
+
+    /**
+     * Original:
+     * Group A: LogicalOlapScan A
+     * Group B: LogicalOlapScan B
+     * Group C: LogicalOlapScan C
+     * Group AB: Join(Group A, Group B)
+     * Group ABC: Join(Group AB, Group C)
+     * <p>
+     * Then:
+     * Rewrite Join(A, B) into Project(Join(B, A)) and copy it into Group AB
+     * <p>
+     * Expected (6 groups, 7 group expressions):
+     * Group A, Group B, Group C: unchanged
+     * Group BA (new): Join(Group B, Group A)
+     * Group AB: Join(Group A, Group B), Project(Group BA)
+     * Group ABC: Join(Group AB, Group C)
+     * The test does not assert concrete group ids.
+     */
+    @Test
+    public void testInsertSameGroup() {
+        PlanChecker.from(MemoTestUtils.createConnectContext(), logicalJoinABC)
+                .transform(
+                        // swap join's children
+                        logicalJoin(logicalOlapScan(), logicalOlapScan()).then(joinBA ->
+                                new LogicalProject<>(Lists.newArrayList(joinBA.getOutput()),
+                                        new LogicalJoin<>(JoinType.INNER_JOIN, joinBA.right(), joinBA.left()))
+                        ))
+                .checkGroupNum(6)
+                .checkGroupExpressionNum(7)
+                .checkMemo(memo -> {
+                    Group root = memo.getRoot();
+                    Assertions.assertEquals(1, root.getLogicalExpressions().size());
+                    GroupExpression joinABC = root.getLogicalExpression();
+                    Assertions.assertEquals(2, joinABC.child(0).getLogicalExpressions().size());
+                    Assertions.assertEquals(1, joinABC.child(1).getLogicalExpressions().size());
+                    GroupExpression joinAB = joinABC.child(0).getLogicalExpressions().get(0);
+                    GroupExpression project = joinABC.child(0).getLogicalExpressions().get(1);
+                    GroupExpression joinBA = project.child(0).getLogicalExpression();
+                    Assertions.assertTrue(joinAB.getPlan() instanceof LogicalJoin);
+                    Assertions.assertTrue(joinBA.getPlan() instanceof LogicalJoin);
+                });
+
+    }
+
+    @Test
+    public void initByOneLevelPlan() {
+        OlapTable table = PlanConstructor.newOlapTable(0, "a", 1);
+        LogicalOlapScan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), table);
+
+        PlanChecker.from(connectContext, scan)
+                .checkGroupNum(1)
+                .matches(
+                        logicalOlapScan().when(scan::equals)
+                );
+    }
+
+    @Test
+    public void initByTwoLevelChainPlan() {
+        OlapTable table = PlanConstructor.newOlapTable(0, "a", 1);
+        LogicalOlapScan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), table);
+
+        LogicalProject<LogicalOlapScan> topProject = new LogicalProject<>(
+                ImmutableList.of(scan.computeOutput().get(0)), scan);
+
+        PlanChecker.from(connectContext, topProject)
+                .checkGroupNum(2)
+                .matches(
+                        logicalProject(
+                                any().when(child -> Objects.equals(child, scan))
+                        ).when(root -> Objects.equals(root, topProject))
+                );
+    }
+
+    @Test
+    public void initByJoinSameUnboundTable() {
+        UnboundRelation scanA = new UnboundRelation(ImmutableList.of("a"));
+
+        LogicalJoin<UnboundRelation, UnboundRelation> topJoin = new LogicalJoin<>(JoinType.INNER_JOIN, scanA, scanA);
+
+        PlanChecker.from(connectContext, topJoin)
+                .checkGroupNum(3)
+                .matches(
+                        logicalJoin(
+                                any().when(left -> Objects.equals(left, scanA)),
+                                any().when(right -> Objects.equals(right, scanA))
+                        ).when(root -> Objects.equals(root, topJoin))
+                );
+    }
+
+    @Test
+    public void initByJoinSameLogicalTable() {
+        IdGenerator<RelationId> generator = RelationId.createGenerator();
+        OlapTable tableA = PlanConstructor.newOlapTable(0, "a", 1);
+        LogicalOlapScan scanA = new LogicalOlapScan(generator.getNextId(), tableA);
+        LogicalOlapScan scanA1 = new LogicalOlapScan(generator.getNextId(), tableA);
+
+        LogicalJoin<LogicalOlapScan, LogicalOlapScan> topJoin = new LogicalJoin<>(JoinType.INNER_JOIN, scanA, scanA1);
+
+        PlanChecker.from(connectContext, topJoin)
+                .checkGroupNum(3)
+                .matches(
+                        logicalJoin(
+                                any().when(left -> Objects.equals(left, scanA)),
+                                any().when(right -> Objects.equals(right, scanA1))
+                        ).when(root -> Objects.equals(root, topJoin))
+                );
+    }
+
+    @Test
+    public void initByTwoLevelJoinPlan() {
+        IdGenerator<RelationId> generator = RelationId.createGenerator();
+        OlapTable tableA = PlanConstructor.newOlapTable(0, "a", 1);
+        OlapTable tableB = PlanConstructor.newOlapTable(0, "b", 1);
+        LogicalOlapScan scanA = new LogicalOlapScan(generator.getNextId(), tableA);
+        LogicalOlapScan scanB = new LogicalOlapScan(generator.getNextId(), tableB);
+
+        LogicalJoin<LogicalOlapScan, LogicalOlapScan> topJoin = new LogicalJoin<>(JoinType.INNER_JOIN, scanA, scanB);
+
+        PlanChecker.from(connectContext, topJoin)
+                .checkGroupNum(3)
+                .matches(
+                        logicalJoin(
+                                any().when(left -> Objects.equals(left, scanA)),
+                                any().when(right -> Objects.equals(right, scanB))
+                        ).when(root -> Objects.equals(root, topJoin))
+                );
+    }
+
+    @Test
+    public void initByThreeLevelChainPlan() {
+        OlapTable table = PlanConstructor.newOlapTable(0, "a", 1);
+        LogicalOlapScan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), table);
+
+        LogicalProject<LogicalOlapScan> project = new LogicalProject<>(
+                ImmutableList.of(scan.computeOutput().get(0)), scan);
+        LogicalFilter<LogicalProject<LogicalOlapScan>> filter = new LogicalFilter<>(
+                new EqualTo(scan.computeOutput().get(0), new IntegerLiteral(1)), project);
+
+        PlanChecker.from(connectContext, filter)
+                .checkGroupNum(3)
+                .matches(
+                        logicalFilter(
+                                logicalProject(
+                                        any().when(child -> Objects.equals(child, scan))
+                                ).when(root -> Objects.equals(root, project))
+                        ).when(root -> Objects.equals(root, filter))
+                );
+    }
+
+    @Test
+    public void initByThreeLevelBushyPlan() {
+        IdGenerator<RelationId> generator = RelationId.createGenerator();
+        OlapTable tableA = PlanConstructor.newOlapTable(0, "a", 1);
+        OlapTable tableB = PlanConstructor.newOlapTable(0, "b", 1);
+        OlapTable tableC = PlanConstructor.newOlapTable(0, "c", 1);
+        OlapTable tableD = PlanConstructor.newOlapTable(0, "d", 1);
+        LogicalOlapScan scanA = new LogicalOlapScan(generator.getNextId(), tableA);
+        LogicalOlapScan scanB = new LogicalOlapScan(generator.getNextId(), tableB);
+        LogicalOlapScan scanC = new LogicalOlapScan(generator.getNextId(), tableC);
+        LogicalOlapScan scanD = new LogicalOlapScan(generator.getNextId(), tableD);
+
+        LogicalJoin<LogicalOlapScan, LogicalOlapScan> leftJoin = new LogicalJoin<>(JoinType.CROSS_JOIN, scanA, scanB);
+        LogicalJoin<LogicalOlapScan, LogicalOlapScan> rightJoin = new LogicalJoin<>(JoinType.CROSS_JOIN, scanC, scanD);
+        LogicalJoin topJoin = new LogicalJoin<>(JoinType.CROSS_JOIN, leftJoin, rightJoin);
+
+        PlanChecker.from(connectContext, topJoin)
+                .checkGroupNum(7)
+                .matches(
+                        logicalJoin(
+                                logicalJoin(
+                                        any().when(child -> Objects.equals(child, scanA)),
+                                        any().when(child -> Objects.equals(child, scanB))
+                                ).when(left -> Objects.equals(left, leftJoin)),
+
+                                logicalJoin(
+                                        any().when(child -> Objects.equals(child, scanC)),
+                                        any().when(child -> Objects.equals(child, scanD))
+                                ).when(right -> Objects.equals(right, rightJoin))
+                        ).when(root -> Objects.equals(root, topJoin))
+                );
+    }
+
     /*
      * A -> A:
      *
@@ -204,25 +433,15 @@ public class MemoRewriteTest implements PatternMatchSupported {
         A a2 = new A(ImmutableList.of("student"), State.ALREADY_REWRITE);
         LogicalLimit<UnboundRelation> limit = new LogicalLimit<>(1, 0, a2);
 
-        PlanChecker.from(connectContext, a)
-                .applyBottomUp(
-                        unboundRelation()
-                                // 4: add state condition to the pattern's predicates
-                                .when(r -> (r instanceof A) && ((A) r).state == State.NOT_REWRITE)
-                                .then(unboundRelation -> {
-                                    // 5: new plan and change state, so this case equal to 'A -> B(C)', which C has
-                                    //    different state with A
-                                    A notRewritePlan = (A) unboundRelation;
-                                    return limit.withChildren(notRewritePlan.withState(State.ALREADY_REWRITE));
-                                }
-                        )
-                )
-                .checkGroupNum(2)
-                .matchesFromRoot(
-                        logicalLimit(
-                                unboundRelation().when(a2::equals)
-                        ).when(limit::equals)
-                );
+        PlanChecker.from(connectContext, a).applyBottomUp(unboundRelation()
+                        // 4: add state condition to the pattern's predicates
+                        .when(r -> (r instanceof A) && ((A) r).state == State.NOT_REWRITE).then(unboundRelation -> {
+                            // 5: new plan and change state, so this case equal to 'A -> B(C)', which C has
+                            //    different state with A
+                            A notRewritePlan = (A) unboundRelation;
+                            return limit.withChildren(notRewritePlan.withState(State.ALREADY_REWRITE));
+                        })).checkGroupNum(2)
+                .matchesFromRoot(logicalLimit(unboundRelation().when(a2::equals)).when(limit::equals));
     }
 
     /*
@@ -359,7 +578,7 @@ public class MemoRewriteTest implements PatternMatchSupported {
                 )
                 .checkGroupNum(1)
                 .matchesFromRoot(
-                    logicalOlapScan().when(student::equals)
+                        logicalOlapScan().when(student::equals)
                 );
     }
 
@@ -801,30 +1020,30 @@ public class MemoRewriteTest implements PatternMatchSupported {
                         )
                 ))
                 .applyTopDown(
-                    logicalLimit(logicalJoin()).then(limit -> {
-                        LogicalJoin<GroupPlan, GroupPlan> join = limit.child();
-                        switch (join.getJoinType()) {
-                            case LEFT_OUTER_JOIN:
-                                return join.withChildren(limit.withChildren(join.left()), join.right());
-                            case RIGHT_OUTER_JOIN:
-                                return join.withChildren(join.left(), limit.withChildren(join.right()));
-                            case CROSS_JOIN:
-                                return join.withChildren(limit.withChildren(join.left()), limit.withChildren(join.right()));
-                            case INNER_JOIN:
-                                if (!join.getHashJoinConjuncts().isEmpty()) {
-                                    return join.withChildren(
-                                            limit.withChildren(join.left()),
-                                            limit.withChildren(join.right())
-                                    );
-                                } else {
+                        logicalLimit(logicalJoin()).then(limit -> {
+                            LogicalJoin<GroupPlan, GroupPlan> join = limit.child();
+                            switch (join.getJoinType()) {
+                                case LEFT_OUTER_JOIN:
+                                    return join.withChildren(limit.withChildren(join.left()), join.right());
+                                case RIGHT_OUTER_JOIN:
+                                    return join.withChildren(join.left(), limit.withChildren(join.right()));
+                                case CROSS_JOIN:
+                                    return join.withChildren(limit.withChildren(join.left()), limit.withChildren(join.right()));
+                                case INNER_JOIN:
+                                    if (!join.getHashJoinConjuncts().isEmpty()) {
+                                        return join.withChildren(
+                                                limit.withChildren(join.left()),
+                                                limit.withChildren(join.right())
+                                        );
+                                    } else {
+                                        return limit;
+                                    }
+                                case LEFT_ANTI_JOIN:
+                                    // todo: support anti join.
+                                default:
                                     return limit;
-                                }
-                            case LEFT_ANTI_JOIN:
-                                // todo: support anti join.
-                            default:
-                                return limit;
-                        }
-                    })
+                            }
+                        })
                 )
                 .matchesFromRoot(
                         logicalJoin(
diff --git a/regression-test/suites/tpch_sf1_p1/tpch_sf1/nereids/q21.groovy b/regression-test/suites/tpch_sf1_p1/tpch_sf1/nereids/q21.groovy
new file mode 100644
index 0000000000..0407830d15
--- /dev/null
+++ b/regression-test/suites/tpch_sf1_p1/tpch_sf1/nereids/q21.groovy
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+suite("tpch_sf1_q21_nereids") {
+    String realDb = context.config.getDbNameByFile(context.file)
+    // get parent directory's group
+    realDb = realDb.substring(0, realDb.lastIndexOf("_"))
+
+    sql "use ${realDb}"
+
+    sql 'set enable_nereids_planner=true'
+    sql 'set enable_fallback_to_original_planner=false'
+
+    qt_select """
+select
+    s_name,
+    count(*) as numwait
+from
+    supplier,
+    lineitem l1,
+    orders,
+    nation
+where
+    s_suppkey = l1.l_suppkey
+    and o_orderkey = l1.l_orderkey
+    and o_orderstatus = 'F'
+    and l1.l_receiptdate > l1.l_commitdate
+    and exists (
+        select
+            *
+        from
+            lineitem l2
+        where
+            l2.l_orderkey = l1.l_orderkey
+            and l2.l_suppkey <> l1.l_suppkey
+    )
+    and not exists (
+        select
+            *
+        from
+            lineitem l3
+        where
+            l3.l_orderkey = l1.l_orderkey
+            and l3.l_suppkey <> l1.l_suppkey
+            and l3.l_receiptdate > l3.l_commitdate
+    )
+    and s_nationkey = n_nationkey
+    and n_name = 'SAUDI ARABIA'
+group by
+    s_name
+order by
+    numwait desc,
+    s_name
+limit 100;
+    """
+
+    qt_select """
+select /*+SET_VAR(exec_mem_limit=8589934592, parallel_fragment_exec_instance_num=16, enable_vectorized_engine=true, batch_size=4096, disable_join_reorder=true, enable_cost_based_join_reorder=true, enable_projection=true) */
+s_name, count(*) as numwait
+from orders join
+(
+  select * from
+  lineitem l2 right semi join
+  (
+    select * from
+    lineitem l3 right anti join
+    (
+      select * from
+      lineitem l1 join
+      (
+        select * from
+        supplier join nation
+        where s_nationkey = n_nationkey
+          and n_name = 'SAUDI ARABIA'
+      ) t1
+      where t1.s_suppkey = l1.l_suppkey and l1.l_receiptdate > l1.l_commitdate
+    ) t2
+    on l3.l_orderkey = t2.l_orderkey and l3.l_suppkey <> t2.l_suppkey and l3.l_receiptdate > l3.l_commitdate
+  ) t3
+  on l2.l_orderkey = t3.l_orderkey and l2.l_suppkey <> t3.l_suppkey
+) t4
+on o_orderkey = t4.l_orderkey and o_orderstatus = 'F'
+group by
+    t4.s_name
+order by
+    numwait desc,
+    t4.s_name
+limit 100;
+    """
+
+}
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org