You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by kx...@apache.org on 2023/07/21 11:28:21 UTC

[doris] 18/18: [enhancement](Nereids) support other join framework in DPHyper (#21835)

This is an automated email from the ASF dual-hosted git repository.

kxiao pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 5e8c60002f42bca0cf5170b83cf8f1113e8d3677
Author: 谢健 <ji...@gmail.com>
AuthorDate: Fri Jul 21 18:31:52 2023 +0800

    [enhancement](Nereids) support other join framework in DPHyper (#21835)
    
    Implement the CD-A conflict detection algorithm so that DPHyper can reorder join types other than inner join.
    The algorithm is described in "On the Correct and Complete Enumeration of the Core Search Space".
---
 .../org/apache/doris/nereids/StatementContext.java | 11 +++++
 .../doris/nereids/jobs/executor/Optimizer.java     |  8 ++--
 .../doris/nereids/jobs/joinorder/JoinOrderJob.java | 19 ++++----
 .../jobs/joinorder/hypergraph/HyperGraph.java      | 54 ++++++++++++++++++----
 .../hypergraph/receiver/PlanReceiver.java          | 41 ++++++++++------
 .../java/org/apache/doris/nereids/memo/Group.java  | 14 ++----
 .../java/org/apache/doris/nereids/memo/Memo.java   | 35 ++++++++++++++
 .../org/apache/doris/nereids/rules/RuleSet.java    |  1 -
 .../apache/doris/nereids/trees/plans/JoinType.java | 44 ++++++++++++++++++
 .../doris/nereids/sqltest/JoinOrderJobTest.java    | 24 ++++++++++
 .../doris/nereids/util/HyperGraphBuilder.java      |  2 +-
 .../org/apache/doris/nereids/util/PlanChecker.java |  4 +-
 12 files changed, 207 insertions(+), 50 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java
index a6e73b32fd..aaeefabf86 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java
@@ -44,6 +44,7 @@ public class StatementContext {
 
     private OriginStatement originStatement;
 
+    private int joinCount = 0;
     private int maxNAryInnerJoin = 0;
 
     private boolean isDpHyp = false;
@@ -101,6 +102,16 @@ public class StatementContext {
         return maxNAryInnerJoin;
     }
 
+    /** Record a continuous join count, keeping only the largest value seen so far. */
+    public void setMaxContinuousJoin(int joinCount) {
+        this.joinCount = Math.max(this.joinCount, joinCount);
+    }
+
+    /** Return the largest continuous join count recorded via {@link #setMaxContinuousJoin(int)}. */
+    public int getMaxContinuousJoin() {
+        return joinCount;
+    }
+
     public boolean isDpHyp() {
         return isDpHyp;
     }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Optimizer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Optimizer.java
index 286336d526..58d298203f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Optimizer.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Optimizer.java
@@ -18,7 +18,6 @@
 package org.apache.doris.nereids.jobs.executor;
 
 import org.apache.doris.nereids.CascadesContext;
-import org.apache.doris.nereids.StatementContext;
 import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
 import org.apache.doris.nereids.jobs.cascades.OptimizeGroupJob;
 import org.apache.doris.nereids.jobs.joinorder.JoinOrderJob;
@@ -57,9 +56,10 @@ public class Optimizer {
         cascadesContext.getJobScheduler().executeJobPool(cascadesContext);
         serializeStatUsed(cascadesContext.getConnectContext());
         // DPHyp optimize
-        StatementContext statementContext = cascadesContext.getStatementContext();
-        boolean isDpHyp = getSessionVariable().enableDPHypOptimizer || statementContext.getMaxNAryInnerJoin()
-                > getSessionVariable().getMaxTableCountUseCascadesJoinReorder();
+        int maxJoinCount = cascadesContext.getMemo().countMaxContinuousJoin();
+        cascadesContext.getStatementContext().setMaxContinuousJoin(maxJoinCount);
+        boolean isDpHyp = getSessionVariable().enableDPHypOptimizer
+                || maxJoinCount > getSessionVariable().getMaxTableCountUseCascadesJoinReorder();
         cascadesContext.getStatementContext().setDpHyp(isDpHyp);
         cascadesContext.getStatementContext().setOtherJoinReorder(false);
         if (!getSessionVariable().isDisableJoinReorder() && isDpHyp) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJob.java
index 466a0cd87d..acc1ebb96d 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJob.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJob.java
@@ -37,6 +37,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import com.google.common.collect.Lists;
 
 import java.util.ArrayList;
+import java.util.BitSet;
 import java.util.HashSet;
 import java.util.Set;
 
@@ -66,7 +67,7 @@ public class JoinOrderJob extends Job {
     }
 
     private Group optimizePlan(Group group) {
-        if (group.isInnerJoinGroup()) {
+        if (group.isValidJoinGroup()) {
             return optimizeJoin(group);
         }
         GroupExpression rootExpr = group.getLogicalExpression();
@@ -111,19 +112,20 @@ public class JoinOrderJob extends Job {
      * @param group root group, should be join type
      * @param hyperGraph build hyperGraph
+     * @return indices of the hyper graph edges added while building the graph for {@code group}
      */
-    public void buildGraph(Group group, HyperGraph hyperGraph) {
+    public BitSet buildGraph(Group group, HyperGraph hyperGraph) {
         if (group.isProjectGroup()) {
-            buildGraph(group.getLogicalExpression().child(0), hyperGraph);
+            BitSet edgeMap = buildGraph(group.getLogicalExpression().child(0), hyperGraph);
             processProjectPlan(hyperGraph, group);
-            return;
+            return edgeMap;
         }
-        if (!group.isInnerJoinGroup()) {
+        if (!group.isValidJoinGroup()) {
             hyperGraph.addNode(optimizePlan(group));
-            return;
+            return new BitSet();
         }
-        buildGraph(group.getLogicalExpression().child(0), hyperGraph);
-        buildGraph(group.getLogicalExpression().child(1), hyperGraph);
-        hyperGraph.addEdge(group);
+        BitSet leftEdgeMap = buildGraph(group.getLogicalExpression().child(0), hyperGraph);
+        BitSet rightEdgeMap = buildGraph(group.getLogicalExpression().child(1), hyperGraph);
+        return hyperGraph.addEdge(group, leftEdgeMap, rightEdgeMap);
     }
 
     /**
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java
index 190b7a0d08..2bc55d8ed2 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java
@@ -35,6 +35,7 @@ import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
 
 import java.util.ArrayList;
+import java.util.BitSet;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -112,7 +113,7 @@ public class HyperGraph {
      * @param group The group that is the end node in graph
      */
     public void addNode(Group group) {
-        Preconditions.checkArgument(!group.isInnerJoinGroup());
+        Preconditions.checkArgument(!group.isValidJoinGroup());
         for (Slot slot : group.getLogicalExpression().getPlan().getOutput()) {
             Preconditions.checkArgument(!slotToNodeMap.containsKey(slot));
             slotToNodeMap.put(slot, LongBitmap.newBitmap(nodes.size()));
@@ -134,10 +135,11 @@ public class HyperGraph {
      *
      * @param group The join group
      */
-    public void addEdge(Group group) {
-        Preconditions.checkArgument(group.isInnerJoinGroup());
+    public BitSet addEdge(Group group, BitSet leftEdgeMap, BitSet rightEdgeMap) {
+        Preconditions.checkArgument(group.isValidJoinGroup());
         LogicalJoin<? extends Plan, ? extends Plan> join = (LogicalJoin) group.getLogicalExpression().getPlan();
         HashMap<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> conjuncts = new HashMap<>();
+
         for (Expression expression : join.getHashJoinConjuncts()) {
             Pair<Long, Long> ends = findEnds(expression);
             if (!conjuncts.containsKey(ends)) {
@@ -152,25 +154,61 @@ public class HyperGraph {
             }
             conjuncts.get(ends).second.add(expression);
         }
+
+        BitSet edgeMap = new BitSet();
+        edgeMap.or(leftEdgeMap);
+        edgeMap.or(rightEdgeMap);
+
         for (Map.Entry<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> entry : conjuncts
                 .entrySet()) {
             LogicalJoin singleJoin = new LogicalJoin<>(join.getJoinType(), entry.getValue().first,
-                    entry.getValue().second, JoinHint.NONE, join.left(), join.right());
+                    entry.getValue().second, JoinHint.NONE, join.getMarkJoinSlotReference(),
+                    Lists.newArrayList(join.left(), join.right()));
             Edge edge = new Edge(singleJoin, edges.size());
             Pair<Long, Long> ends = entry.getKey();
-            edge.setLeft(ends.first);
-            edge.setOriginalLeft(ends.first);
-            edge.setRight(ends.second);
-            edge.setOriginalRight(ends.second);
+            initEdgeEnds(ends, edge, leftEdgeMap, rightEdgeMap);
             for (int nodeIndex : LongBitmap.getIterator(edge.getReferenceNodes())) {
                 nodes.get(nodeIndex).attachEdge(edge);
             }
+            edgeMap.set(edge.getIndex());
             edges.add(edge);
         }
+
+        return edgeMap;
         // In MySQL, each edge is reversed and store in edges again for reducing the branch miss
         // We don't implement this trick now.
     }
 
+    // Set edge ends by CD-A, from "On the Correct and Complete Enumeration of the Core Search Space".
+    // NOTE(review): CD-A adds T(left(a))/T(right(a)) (whole subtrees); getLeft()/getRight() may be narrower.
+    private void initEdgeEnds(Pair<Long, Long> ends, Edge edge, BitSet leftEdges, BitSet rightEdges) {
+        long left = ends.first;
+        long right = ends.second;
+        for (int i = leftEdges.nextSetBit(0); i >= 0; i = leftEdges.nextSetBit(i + 1)) {
+            Edge lEdge = edges.get(i);
+            if (!JoinType.isAssoc(lEdge.getJoinType(), edge.getJoinType())) {
+                left = LongBitmap.or(left, lEdge.getLeft());
+            }
+            if (!JoinType.isLAssoc(lEdge.getJoinType(), edge.getJoinType())) {
+                left = LongBitmap.or(left, lEdge.getRight());
+            }
+        }
+        for (int i = rightEdges.nextSetBit(0); i >= 0; i = rightEdges.nextSetBit(i + 1)) {
+            Edge rEdge = edges.get(i);
+            if (!JoinType.isAssoc(edge.getJoinType(), rEdge.getJoinType())) {
+                right = LongBitmap.or(right, rEdge.getRight());
+            }
+            if (!JoinType.isRAssoc(edge.getJoinType(), rEdge.getJoinType())) {
+                right = LongBitmap.or(right, rEdge.getLeft());
+            }
+        }
+
+        edge.setOriginalLeft(left);
+        edge.setOriginalRight(right);
+        edge.setLeft(left);
+        edge.setRight(right);
+    }
+
     private int findRoot(List<Integer> parent, int idx) {
         int root = parent.get(idx);
         if (root != idx) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java
index 8023d806da..c29cac616a 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java
@@ -59,6 +59,7 @@ import java.util.List;
 import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
+import javax.annotation.Nullable;
 
 /**
  * The Receiver is used for cached the plan that has been emitted and build the new plan
@@ -117,6 +118,9 @@ public class PlanReceiver implements AbstractReceiver {
         List<Expression> hashConjuncts = new ArrayList<>();
         List<Expression> otherConjuncts = new ArrayList<>();
         JoinType joinType = extractJoinTypeAndConjuncts(edges, hashConjuncts, otherConjuncts);
+        if (joinType == null) {
+            return true;
+        }
         long fullKey = LongBitmap.newBitmapUnion(left, right);
 
         List<Plan> physicalJoins = proposeAllPhysicalJoins(joinType, leftPlan, rightPlan, hashConjuncts,
@@ -207,30 +211,37 @@ public class PlanReceiver implements AbstractReceiver {
         // Check whether only NSL can be performed
         LogicalProperties joinProperties = new LogicalProperties(
                 () -> JoinUtils.getJoinOutput(joinType, left, right));
+        List<Plan> plans = Lists.newArrayList();
         if (JoinUtils.shouldNestedLoopJoin(joinType, hashConjuncts)) {
-            return Lists.newArrayList(
-                    new PhysicalNestedLoopJoin<>(joinType, hashConjuncts, otherConjuncts,
+            plans.add(new PhysicalNestedLoopJoin<>(joinType, hashConjuncts, otherConjuncts,
                             Optional.empty(), joinProperties,
-                            left, right),
-                    new PhysicalNestedLoopJoin<>(joinType.swap(), hashConjuncts, otherConjuncts, Optional.empty(),
-                            joinProperties,
-                            right, left));
+                            left, right));
+            if (joinType.isSwapJoinType()) {
+                plans.add(new PhysicalNestedLoopJoin<>(joinType.swap(), hashConjuncts, otherConjuncts, Optional.empty(),
+                        joinProperties,
+                        right, left));
+            }
         } else {
-            return Lists.newArrayList(
-                    new PhysicalHashJoin<>(joinType, hashConjuncts, otherConjuncts, JoinHint.NONE, Optional.empty(),
-                            joinProperties,
-                            left, right),
-                    new PhysicalHashJoin<>(joinType.swap(), hashConjuncts, otherConjuncts, JoinHint.NONE,
-                            Optional.empty(),
-                            joinProperties,
-                            right, left));
+            plans.add(new PhysicalHashJoin<>(joinType, hashConjuncts, otherConjuncts, JoinHint.NONE, Optional.empty(),
+                    joinProperties,
+                    left, right));
+            if (joinType.isSwapJoinType()) {
+                plans.add(new PhysicalHashJoin<>(joinType.swap(), hashConjuncts, otherConjuncts, JoinHint.NONE,
+                        Optional.empty(),
+                        joinProperties,
+                        right, left));
+            }
         }
+        return plans;
     }
 
-    private JoinType extractJoinTypeAndConjuncts(List<Edge> edges, List<Expression> hashConjuncts,
+    private @Nullable JoinType extractJoinTypeAndConjuncts(List<Edge> edges, List<Expression> hashConjuncts,
             List<Expression> otherConjuncts) {
         JoinType joinType = null;
         for (Edge edge : edges) {
+            if (edge.getJoinType() != joinType && joinType != null) {
+                return null;
+            }
             Preconditions.checkArgument(joinType == null || joinType == edge.getJoinType());
             joinType = edge.getJoinType();
             for (Expression expression : edge.getExpressions()) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java
index cfaeae6009..9291c32d90 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java
@@ -21,7 +21,6 @@ import org.apache.doris.common.Pair;
 import org.apache.doris.nereids.cost.Cost;
 import org.apache.doris.nereids.properties.LogicalProperties;
 import org.apache.doris.nereids.properties.PhysicalProperties;
-import org.apache.doris.nereids.trees.plans.JoinType;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
 import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
@@ -374,16 +373,11 @@ public class Group {
     /**
-     * This function used to check whether the group is an end node in DPHyp
+     * Check whether the group is a join DPHyp can reorder: a non-mark join with at least one join expression.
      */
-    public boolean isInnerJoinGroup() {
+    public boolean isValidJoinGroup() {
         Plan plan = getLogicalExpression().getPlan();
-        if (plan instanceof LogicalJoin
-                && ((LogicalJoin) plan).getJoinType() == JoinType.INNER_JOIN) {
-            // Right now, we only support inner join
-            Preconditions.checkArgument(!((LogicalJoin) plan).getExpressions().isEmpty(),
-                    "inner join must have join conjuncts");
-            return true;
-        }
-        return false;
+        return plan instanceof LogicalJoin
+                && !((LogicalJoin) plan).isMarkJoin()
+                && !((LogicalJoin) plan).getExpressions().isEmpty();
     }
 
     public boolean isProjectGroup() {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
index a19dbb8175..0c82914366 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
@@ -166,6 +166,48 @@ public class Memo {
         return plan;
     }
 
+    /**
+     * Return the largest number of continuous joins (joins separated only by projects) in the memo.
+     */
+    public int countMaxContinuousJoin() {
+        return countGroupJoin(root).second;
+    }
+
+    /**
+     * Count valid joins (see {@link Group#isValidJoinGroup()}) in the sub-plan rooted at {@code group}.
+     *
+     * @param group root of the sub-plan; children are reached through its logical expression
+     * @return a pair of (number of joins in the continuous join region ending at {@code group},
+     *         largest such number anywhere in the sub-plan)
+     */
+    public Pair<Integer, Integer> countGroupJoin(Group group) {
+        GroupExpression logicalExpr = group.getLogicalExpression();
+        List<Pair<Integer, Integer>> children = new ArrayList<>();
+        for (Group child : logicalExpr.children()) {
+            children.add(countGroupJoin(child));
+        }
+
+        // A project is transparent: it neither breaks nor extends the join region below it.
+        if (group.isProjectGroup()) {
+            return children.get(0);
+        }
+
+        int maxJoinCount = 0;
+        for (Pair<Integer, Integer> child : children) {
+            maxJoinCount = Math.max(maxJoinCount, child.second);
+        }
+
+        // A join extends its children's regions by itself; any other operator ends the region (count 0).
+        int continuousJoinCount = 0;
+        if (group.isValidJoinGroup()) {
+            continuousJoinCount = 1;
+            for (Pair<Integer, Integer> child : children) {
+                continuousJoinCount += child.first;
+            }
+        }
+        return Pair.of(continuousJoinCount, Math.max(continuousJoinCount, maxJoinCount));
+    }
+
     /**
      * Add plan to Memo.
      *
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
index 98e42eb311..f033c6d465 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
@@ -203,7 +203,6 @@ public class RuleSet {
 
     public static final List<Rule> DPHYP_REORDER_RULES = ImmutableList.<Rule>builder()
             .add(JoinCommute.BUSHY.build())
-            .addAll(OTHER_REORDER_RULES)
             .build();
 
     public List<Rule> getDPHypReorderRules() {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/JoinType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/JoinType.java
index 0f3cbcfdae..b7d485059e 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/JoinType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/JoinType.java
@@ -21,6 +21,7 @@ import org.apache.doris.analysis.JoinOperator;
 import org.apache.doris.common.AnalysisException;
 
 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
 
 import java.util.Map;
 
@@ -53,6 +54,37 @@ public enum JoinType {
             .put(RIGHT_ANTI_JOIN, LEFT_ANTI_JOIN)
             .build();
 
+    // TODO: the right-semi/right-anti/right-outer join is not derived in paper. We need to derive them
+
+    /*ASSOC:
+     *        topJoin       bottomJoin
+     *        /     \         /     \
+     *   bottomJoin  C  ->   A     topJoin
+     *    /    \                   /    \
+     *   A      B                 B      C
+     * ====================================
+     *             topJoin  bottomJoin
+     * topJoin        -          -
+     * bottomJoin     +          -
+     */
+    private static final Map<JoinType, ImmutableSet<JoinType>> assocJoinMatrix
+            = ImmutableMap.<JoinType, ImmutableSet<JoinType>>builder()
+            .put(CROSS_JOIN, ImmutableSet.of(CROSS_JOIN, INNER_JOIN))
+            .put(INNER_JOIN, ImmutableSet.of(CROSS_JOIN, INNER_JOIN))
+            .build();
+
+    private static final Map<JoinType, ImmutableSet<JoinType>> lAssocJoinMatrix
+            = ImmutableMap.<JoinType, ImmutableSet<JoinType>>builder()
+            .put(CROSS_JOIN, ImmutableSet.of(CROSS_JOIN, INNER_JOIN))
+            .put(INNER_JOIN, ImmutableSet.of(CROSS_JOIN, INNER_JOIN))
+            .build();
+
+    private static final Map<JoinType, ImmutableSet<JoinType>> rAssocJoinMatrix
+            = ImmutableMap.<JoinType, ImmutableSet<JoinType>>builder()
+            .put(CROSS_JOIN, ImmutableSet.of(CROSS_JOIN, INNER_JOIN))
+            .put(INNER_JOIN, ImmutableSet.of(CROSS_JOIN, INNER_JOIN))
+            .build();
+
     /**
      * Convert join type in Nereids to legacy join type in Doris.
      *
@@ -157,6 +189,18 @@ public enum JoinType {
         return joinSwapMap.containsKey(this);
     }
 
+    public static boolean isAssoc(JoinType join1, JoinType join2) {
+        return assocJoinMatrix.containsKey(join1) && assocJoinMatrix.get(join1).contains(join2);
+    }
+
+    public static boolean isLAssoc(JoinType join1, JoinType join2) {
+        return lAssocJoinMatrix.containsKey(join1) && lAssocJoinMatrix.get(join1).contains(join2);
+    }
+
+    public static boolean isRAssoc(JoinType join1, JoinType join2) {
+        return rAssocJoinMatrix.containsKey(join1) && rAssocJoinMatrix.get(join1).contains(join2);
+    }
+
     public JoinType swap() {
         return joinSwapMap.get(this);
     }
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/JoinOrderJobTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/JoinOrderJobTest.java
index 3f0fcf6fcc..af272f3d5d 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/JoinOrderJobTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/JoinOrderJobTest.java
@@ -17,8 +17,10 @@
 
 package org.apache.doris.nereids.sqltest;
 
+import org.apache.doris.nereids.memo.Memo;
 import org.apache.doris.nereids.util.PlanChecker;
 
+import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
 public class JoinOrderJobTest extends SqlTestBase {
@@ -84,4 +86,26 @@ public class JoinOrderJobTest extends SqlTestBase {
                 .rewrite()
                 .dpHypOptimize();
     }
+
+    @Test
+    protected void testCountJoin() {
+        String sql = "select count(*) \n"
+                + "from \n"
+                + "T1, \n"
+                + "(\n"
+                + "select sum(T2.score + T3.score) as score from T2 join T3 on T2.id = T3.id"
+                + ") subTable, \n"
+                + "( \n"
+                + "select sum(T4.id*2) as id from T4"
+                + ") doubleT4 \n"
+                + "where \n"
+                + "T1.id = doubleT4.id and \n"
+                + "T1.score = subTable.score;\n";
+        Memo memo = PlanChecker.from(connectContext)
+                .analyze(sql)
+                .rewrite()
+                .getCascadesContext()
+                .getMemo();
+        Assertions.assertEquals(2, memo.countMaxContinuousJoin());
+    }
 }
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/HyperGraphBuilder.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/HyperGraphBuilder.java
index a834f2bd92..ce76b58183 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/HyperGraphBuilder.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/HyperGraphBuilder.java
@@ -296,7 +296,7 @@ public class HyperGraphBuilder {
     }
 
     private void injectRowcount(Group group) {
-        if (!group.isInnerJoinGroup()) {
+        if (!group.isValidJoinGroup()) {
             LogicalOlapScan scanPlan = (LogicalOlapScan) group.getLogicalExpression().getPlan();
             Statistics stats = injectRowcount(scanPlan);
             group.setStatistics(stats);
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java
index ab57a650f6..613dbf1731 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java
@@ -230,7 +230,7 @@ public class PlanChecker {
         double now = System.currentTimeMillis();
         Group root = cascadesContext.getMemo().getRoot();
         boolean changeRoot = false;
-        if (root.isInnerJoinGroup()) {
+        if (root.isValidJoinGroup()) {
             // If the root group is join group, DPHyp can change the root group.
             // To keep the root group is not changed, we add a dummy project operator above join
             List<Slot> outputs = root.getLogicalExpression().getPlan().getOutput();
@@ -419,7 +419,7 @@ public class PlanChecker {
     public PlanChecker orderJoin() {
         Group root = cascadesContext.getMemo().getRoot();
         boolean changeRoot = false;
-        if (root.isInnerJoinGroup()) {
+        if (root.isValidJoinGroup()) {
             List<Slot> outputs = root.getLogicalExpression().getPlan().getOutput();
             // FIXME: can't match type, convert List<Slot> to List<NamedExpression>
             GroupExpression newExpr = new GroupExpression(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org